#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the transform.util classes."""

# pytype: skip-file
# pylint: disable=too-many-function-args

import collections
import hashlib
import hmac
import importlib
import logging
import math
import random
import re
import sys
import time
import unittest
import warnings
from collections.abc import Mapping
from datetime import datetime

import mock
import pytest
import pytz
from cryptography.fernet import Fernet
from cryptography.fernet import InvalidToken
from parameterized import param
from parameterized import parameterized

import apache_beam as beam
from apache_beam import GroupByKey
from apache_beam import Map
from apache_beam import WindowInto
from apache_beam.coders import coders
from apache_beam.metrics import MetricsFilter
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.options.pipeline_options import TypeOptions
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.pvalue import AsList
from apache_beam.pvalue import AsSingleton
from apache_beam.runners import pipeline_context
from apache_beam.testing.synthetic_pipeline import SyntheticSource
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_stream import TestStream
from apache_beam.testing.util import SortLists
from apache_beam.testing.util import TestWindowedValue
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import contains_in_any_order
from apache_beam.testing.util import equal_to
from apache_beam.transforms import trigger
from apache_beam.transforms import util
from apache_beam.transforms import window
from apache_beam.transforms.core import FlatMapTuple
from apache_beam.transforms.trigger import AfterCount
from apache_beam.transforms.trigger import Repeatedly
from apache_beam.transforms.util import GcpSecret
from apache_beam.transforms.util import Secret
from apache_beam.transforms.window import FixedWindows
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import GlobalWindows
from apache_beam.transforms.window import IntervalWindow
from apache_beam.transforms.window import Sessions
from apache_beam.transforms.window import SlidingWindows
from apache_beam.transforms.window import TimestampedValue
from apache_beam.typehints import typehints
from apache_beam.typehints.sharded_key_type import ShardedKeyType
from apache_beam.utils import proto_utils
from apache_beam.utils import timestamp
from apache_beam.utils.timestamp import MAX_TIMESTAMP
from apache_beam.utils.timestamp import MIN_TIMESTAMP
from apache_beam.utils.windowed_value import PANE_INFO_UNKNOWN
from apache_beam.utils.windowed_value import PaneInfo
from apache_beam.utils.windowed_value import PaneInfoTiming
from apache_beam.utils.windowed_value import WindowedValue

# secretmanager is an optional dependency: when the import fails the name is
# bound to None so that tests further down can test for it and skip.
try:
  from google.cloud import secretmanager
except ImportError:
  secretmanager = None  # type: ignore[assignment]

# Silence FutureWarning for the module pattern given below.
# NOTE(review): the pattern reads 'apache_beam.transform.util_test' while the
# package imported throughout this file is 'apache_beam.transforms' -- confirm
# the module name is spelled as intended.
warnings.filterwarnings(
    'ignore', category=FutureWarning, module='apache_beam.transform.util_test')


class _Unpicklable:
  """Value wrapper whose pickling hooks always raise NotImplementedError.

  Attributes are limited to ``value``, the object handed to the constructor.
  """
  def __init__(self, value):
    self.value = value

  def __getstate__(self):
    # Refuse to produce a picklable state.
    raise NotImplementedError()

  def __setstate__(self, state):
    # Refuse to be restored from a pickled state.
    raise NotImplementedError()


class _UnpicklableCoder(beam.coders.Coder):
  """Coder for ``_Unpicklable`` that serializes only the wrapped value.

  The wrapped value is written as text and read back with ``int``.
  """
  def encode(self, value):
    # Text form of the wrapped value, encoded to bytes.
    text = str(value.value)
    return text.encode()

  def decode(self, encoded):
    # Parse the text back into an int and re-wrap it.
    number = int(encoded.decode())
    return _Unpicklable(number)

  def to_type_hint(self):
    return _Unpicklable

  def is_deterministic(self):
    return True


class CoGroupByKeyTest(unittest.TestCase):
  """Runs beam.CoGroupByKey over tuple, iterable, list and dict inputs."""

  # Output expected when the two shared inputs are grouped positionally.
  _POSITIONAL_EXPECTED = [('a', ([1, 2], [5, 6])), ('b', ([3], [])),
                          ('c', ([4], [7, 8]))]

  @staticmethod
  def _create_inputs(pipeline):
    """Adds the two keyed sources shared by several tests to ``pipeline``.

    Args:
      pipeline: pipeline that the 'Start 1' and 'Start 2' sources attach to.

    Returns:
      A ``(first, second)`` tuple of the created PCollections.
    """
    first = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2), ('b', 3),
                                                 ('c', 4)])
    second = pipeline | 'Start 2' >> beam.Create([('a', 5), ('a', 6),
                                                  ('c', 7), ('c', 8)])
    return first, second

  def test_co_group_by_key_on_tuple(self):
    """Groups a tuple of PCollections."""
    with TestPipeline() as pipeline:
      inputs = self._create_inputs(pipeline)
      grouped = inputs | beam.CoGroupByKey()
      assert_that(grouped | SortLists, equal_to(self._POSITIONAL_EXPECTED))

  def test_co_group_by_key_on_iterable(self):
    """Groups an iterator over PCollections."""
    with TestPipeline() as pipeline:
      first, second = self._create_inputs(pipeline)
      grouped = iter([first, second]) | beam.CoGroupByKey()
      assert_that(grouped | SortLists, equal_to(self._POSITIONAL_EXPECTED))

  def test_co_group_by_key_on_list(self):
    """Groups a list of PCollections."""
    with TestPipeline() as pipeline:
      first, second = self._create_inputs(pipeline)
      grouped = [first, second] | beam.CoGroupByKey()
      assert_that(grouped | SortLists, equal_to(self._POSITIONAL_EXPECTED))

  def test_co_group_by_key_on_dict(self):
    """Groups a dict of PCollections; results are keyed by the dict tags."""
    with TestPipeline() as pipeline:
      first, second = self._create_inputs(pipeline)
      grouped = {'X': first, 'Y': second} | beam.CoGroupByKey()
      expected = [
          ('a', {
              'X': [1, 2], 'Y': [5, 6]
          }),
          ('b', {
              'X': [3], 'Y': []
          }),
          ('c', {
              'X': [4], 'Y': [7, 8]
          }),
      ]
      assert_that(grouped | SortLists, equal_to(expected))

  def test_co_group_by_key_on_dict_with_tuple_keys(self):
    """Uses a nested tuple as the grouping key."""
    with TestPipeline() as pipeline:
      nested_key = ('a', ('b', 'c'))
      left = pipeline | 'Start 1' >> beam.Create([(nested_key, 1)])
      right = pipeline | 'Start 2' >> beam.Create([(nested_key, 2)])
      grouped = {'X': left, 'Y': right} | beam.CoGroupByKey()
      expected = [(nested_key, {'X': [1], 'Y': [2]})]
      assert_that(grouped | SortLists, equal_to(expected))

  def test_co_group_by_key_on_empty(self):
    """Empty containers need an explicit pipeline and yield no output."""
    with TestPipeline() as pipeline:
      from_tuple = tuple() | 'EmptyTuple' >> beam.CoGroupByKey(
          pipeline=pipeline)
      assert_that(from_tuple, equal_to([]), label='AssertEmptyTuple')

      from_list = [] | 'EmptyList' >> beam.CoGroupByKey(pipeline=pipeline)
      assert_that(from_list, equal_to([]), label='AssertEmptyList')

      from_iterable = iter([]) | 'EmptyIterable' >> beam.CoGroupByKey(
          pipeline=pipeline)
      assert_that(from_iterable, equal_to([]), label='AssertEmptyIterable')

      from_dict = {} | 'EmptyDict' >> beam.CoGroupByKey(pipeline=pipeline)
      assert_that(from_dict, equal_to([]), label='AssertEmptyDict')

  def test_co_group_by_key_on_one(self):
    """Groups a single PCollection supplied in each container type."""
    with TestPipeline() as pipeline:
      pcoll = pipeline | beam.Create([('a', 1), ('b', 2)])
      expected = [('a', ([1], )), ('b', ([2], ))]

      from_tuple = (pcoll, ) | 'OneTuple' >> beam.CoGroupByKey()
      assert_that(from_tuple, equal_to(expected), label='AssertOneTuple')

      from_list = [pcoll] | 'OneList' >> beam.CoGroupByKey()
      assert_that(from_list, equal_to(expected), label='AssertOneList')

      from_iterable = iter([pcoll]) | 'OneIterable' >> beam.CoGroupByKey()
      assert_that(
          from_iterable, equal_to(expected), label='AssertOneIterable')

      # Re-shape the dict result into the positional form used by `expected`.
      from_dict = ({
          'tag': pcoll
      }
                   | 'OneDict' >> beam.CoGroupByKey()
                   | beam.MapTuple(lambda k, v: (k, (v['tag'], ))))
      assert_that(from_dict, equal_to(expected), label='AssertOneDict')

  def test_co_group_by_key_on_unpickled(self):
    """Groups values of a type whose pickling hooks raise."""
    beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)
    values = [_Unpicklable(i) for i in range(5)]
    with TestPipeline() as pipeline:
      keyed = pipeline | beam.Create(values) | beam.WithKeys(lambda x: x)
      grouped = {'x': keyed} | beam.CoGroupByKey()
      # Emit the key's value followed by twice the first grouped value.
      flattened = grouped | beam.FlatMapTuple(
          lambda k, tagged: (k.value, tagged['x'][0].value * 2))
      expected = [0, 0, 1, 2, 2, 4, 3, 6, 4, 8]
      assert_that(flattened, equal_to(expected))


class FakeSecret(beam.Secret):
  """In-memory Secret that returns a fixed key or raises on request.

  Args:
    version_name: accepted for signature compatibility; not stored.
    should_throw: when true, ``get_secret_bytes`` raises RuntimeError.
  """
  def __init__(self, version_name=None, should_throw=False):
    self._should_throw = should_throw
    self._secret = b'aKwI2PmqYFt2p5tNKCyBS5qYmHhHsGZcyZrnZQiQ-uE='

  def get_secret_bytes(self) -> bytes:
    """Returns the fixed secret bytes, or raises if configured to throw."""
    if not self._should_throw:
      return self._secret
    raise RuntimeError('Exception retrieving secret')


class MockNoOpDecrypt(beam.transforms.util._DecryptMessage):
  """_DecryptMessage subclass that also checks its input was encrypted.

  It delegates to the parent ``process`` and, for elements keyed by 'a', 'b'
  or 'c', verifies that the incoming key is a known HMAC and that each
  incoming pair holds two tokens decryptable with the secret.
  """

  # Plain-text keys used by the tests that install this mock.
  _TRACKED_KEYS = ('a', 'b', 'c')

  def __init__(self, hmac_key_secret, key_coder, value_coder):
    secret_bytes = hmac_key_secret.get_secret_bytes()
    self.fernet_tester = Fernet(secret_bytes)
    # HMAC-SHA256 digest of every tracked key, as encoded by key_coder.
    self.known_hmacs = [
        hmac.new(secret_bytes, key_coder.encode(key), hashlib.sha256).digest()
        for key in self._TRACKED_KEYS
    ]
    super().__init__(hmac_key_secret, key_coder, value_coder)

  def _require_encrypted(self, token):
    """Raises ValueError unless ``token`` decrypts with the test secret."""
    try:
      self.fernet_tester.decrypt(token, None)
    except InvalidToken:
      raise ValueError(f'GBK produced unencrypted value {token}')

  def process(self, element):
    """Decrypts ``element`` via the parent and validates tracked elements.

    Returns:
      The list of elements produced by the parent ``process``.

    Raises:
      ValueError: if a tracked element arrived with an unknown HMAC key or
        with a token that cannot be decrypted.
    """
    final_elements = list(super().process(element))
    # Only validate the elements under test; other elements (the original
    # comment mentions a GBK inside the assertion that uses None as the key)
    # pass straight through.
    tracked = [e for e in final_elements if e[0] in ['a', 'b', 'c']]
    if not tracked:
      return final_elements

    hmac_key, encrypted_pairs = element
    if hmac_key not in self.known_hmacs:
      raise ValueError(f'GBK produced unencrypted value {hmac_key}')
    for pair in encrypted_pairs:
      for position in (0, 1):
        self._require_encrypted(pair[position])

    return final_elements


class SecretTest(unittest.TestCase):
  """Checks Secret.parse_secret_option on valid and invalid option strings."""

  # Option strings paired with the secret they are expected to parse to.
  _PARSEABLE_CASES = [
      param(
          secret_string='type:GcpSecret;version_name:my_secret/versions/latest',
          secret=GcpSecret('my_secret/versions/latest')),
      param(
          secret_string='type:GcpSecret;version_name:foo',
          secret=GcpSecret('foo')),
      param(
          secret_string='type:gcpsecreT;version_name:my_secret/versions/latest',
          secret=GcpSecret('my_secret/versions/latest')),
  ]

  # Option strings paired with a regex the raised error must match.
  _REJECTED_CASES = [
      param(
          secret_string='version_name:foo',
          exception_str='must contain a valid type parameter'),
      param(
          secret_string='type:gcpsecreT',
          exception_str='missing 1 required positional argument'),
      param(
          secret_string='type:gcpsecreT;version_name:foo;extra:val',
          exception_str='Invalid secret parameter extra'),
  ]

  @parameterized.expand(_PARSEABLE_CASES)
  def test_secret_manager_parses_correctly(self, secret_string, secret):
    """The parsed option equals the expected secret object."""
    parsed = Secret.parse_secret_option(secret_string)
    self.assertEqual(secret, parsed)

  @parameterized.expand(_REJECTED_CASES)
  def test_secret_manager_throws_on_invalid(self, secret_string, exception_str):
    """Parsing raises an exception whose message matches exception_str."""
    with self.assertRaisesRegex(Exception, exception_str):
      Secret.parse_secret_option(secret_string)


class GroupByEncryptedKeyTest(unittest.TestCase):
  """Covers beam.GroupByEncryptedKey and GroupByKey with the gbek option."""

  # Keyed input fed to every pipeline in this class.
  _KEYED_INPUT = [('a', 1), ('a', 2), ('b', 3), ('c', 4)]
  # _KEYED_INPUT grouped by key.
  _GROUPED = [('a', [1, 2]), ('b', [3]), ('c', [4])]

  @classmethod
  def setUpClass(cls):
    """Ensures a Secret Manager secret exists when the client is importable."""
    if secretmanager is None:
      return

    cls.project_id = 'apache-beam-testing'
    # The secret id combines a creation time tag with the Python version.
    created_tag = datetime.now().strftime('%m%d_%H%M%S')
    version_tag = f'_py{sys.version_info.major}{sys.version_info.minor}'
    cls.secret_id = 'gbek_util_secret_tests_' + created_tag + version_tag
    cls.client = secretmanager.SecretManagerServiceClient()
    cls.project_path = f'projects/{cls.project_id}'
    cls.secret_path = f'{cls.project_path}/secrets/{cls.secret_id}'

    try:
      cls.client.get_secret(request={'name': cls.secret_path})
    except Exception:
      # Lookup failed: create the secret and give it an initial version.
      create_request = {
          'parent': cls.project_path,
          'secret_id': cls.secret_id,
          'secret': {
              'replication': {
                  'automatic': {}
              }
          },
      }
      cls.client.create_secret(request=create_request)
      version_request = {
          'parent': cls.secret_path,
          'payload': {
              'data': Secret.generate_secret_bytes()
          },
      }
      cls.client.add_secret_version(request=version_request)

    version_name = f'{cls.secret_path}/versions/latest'
    cls.gcp_secret = GcpSecret(version_name)
    cls.secret_option = f'type:GcpSecret;version_name:{version_name}'

  @classmethod
  def tearDownClass(cls):
    """Deletes the secret used by this class, if a client was available."""
    if secretmanager is None:
      return
    cls.client.delete_secret(request={'name': cls.secret_path})

  def test_gbek_fake_secret_manager_roundtrips(self):
    """GroupByEncryptedKey with a fake secret yields the plain grouping."""
    secret = FakeSecret()

    with TestPipeline() as pipeline:
      keyed = pipeline | 'Start 1' >> beam.Create(self._KEYED_INPUT)
      grouped = keyed | beam.GroupByEncryptedKey(secret)
      assert_that(grouped, equal_to(self._GROUPED))

  @unittest.skipIf(secretmanager is None, 'GCP dependencies are not installed')
  def test_gbk_with_gbek_option_fake_secret_manager_roundtrips(self):
    """GroupByKey with the gbek option set yields the plain grouping."""
    options = PipelineOptions()
    options.view_as(SetupOptions).gbek = self.secret_option

    with beam.Pipeline(options=options) as pipeline:
      keyed = pipeline | 'Start 1' >> beam.Create(self._KEYED_INPUT)
      grouped = keyed | beam.GroupByKey()
      # Sort each group's values before comparing.
      ordered = grouped | beam.Map(lambda x: (x[0], sorted(x[1])))
      assert_that(ordered, equal_to(self._GROUPED))

  @mock.patch('apache_beam.transforms.util._DecryptMessage', MockNoOpDecrypt)
  def test_gbek_fake_secret_manager_actually_does_encryption(self):
    """Runs with MockNoOpDecrypt patched in to check data is encrypted."""
    secret = FakeSecret()

    with TestPipeline('FnApiRunner') as pipeline:
      keyed = pipeline | 'Start 1' >> beam.Create(self._KEYED_INPUT)
      grouped = keyed | beam.GroupByEncryptedKey(secret)
      assert_that(grouped, equal_to(self._GROUPED))

  @mock.patch('apache_beam.transforms.util._DecryptMessage', MockNoOpDecrypt)
  @mock.patch('apache_beam.transforms.util.GcpSecret', FakeSecret)
  def test_gbk_actually_does_encryption(self):
    """GroupByKey with the gbek option, checked through MockNoOpDecrypt."""
    options = PipelineOptions()
    # Version of GcpSecret doesn't matter since it is replaced by FakeSecret
    options.view_as(SetupOptions).gbek = 'type:GcpSecret;version_name:Foo'

    with TestPipeline('FnApiRunner', options=options) as pipeline:
      keyed = pipeline | 'Start 1' >> beam.Create(
          self._KEYED_INPUT, reshuffle=False)
      grouped = keyed | beam.GroupByKey()
      assert_that(grouped, equal_to(self._GROUPED))

  def test_gbek_fake_secret_manager_throws(self):
    """A secret that fails to load surfaces as a RuntimeError."""
    secret = FakeSecret(should_throw=True)

    with self.assertRaisesRegex(RuntimeError, r'Exception retrieving secret'):
      with TestPipeline() as pipeline:
        keyed = pipeline | 'Start 1' >> beam.Create(self._KEYED_INPUT)
        grouped = keyed | beam.GroupByEncryptedKey(secret)
        assert_that(grouped, equal_to(self._GROUPED))

  @unittest.skipIf(secretmanager is None, 'GCP dependencies are not installed')
  def test_gbek_gcp_secret_manager_roundtrips(self):
    """GroupByEncryptedKey with the class-level GcpSecret round-trips."""
    with TestPipeline() as pipeline:
      keyed = pipeline | 'Start 1' >> beam.Create(self._KEYED_INPUT)
      grouped = keyed | beam.GroupByEncryptedKey(self.gcp_secret)
      assert_that(grouped, equal_to(self._GROUPED))

  @unittest.skipIf(secretmanager is None, 'GCP dependencies are not installed')
  def test_gbek_gcp_secret_manager_throws(self):
    """A GcpSecret built from a bad path surfaces as a RuntimeError."""
    bad_secret = GcpSecret('bad_path/versions/latest')

    with self.assertRaisesRegex(RuntimeError,
                                r'Failed to retrieve secret bytes'):
      with TestPipeline() as pipeline:
        keyed = pipeline | 'Start 1' >> beam.Create(self._KEYED_INPUT)
        grouped = keyed | beam.GroupByEncryptedKey(bad_secret)
        assert_that(grouped, equal_to(self._GROUPED))


class FakeClock(object):
  """Manually advanced clock for tests.

  Calling the instance returns the current fake time; ``sleep`` moves that
  time forward without blocking.

  Args:
    now: initial time value. When omitted (None), ``time.time()`` is read at
      construction time. A ``now=time.time()`` default would instead be
      evaluated only once, when the class is defined, so every default
      instance would start at the module import time.
  """
  def __init__(self, now=None):
    self._now = time.time() if now is None else now

  def __call__(self):
    """Returns the current fake time."""
    return self._now

  def sleep(self, duration):
    """Advances the fake time by ``duration`` and returns immediately."""
    self._now += duration


class BatchElementsTest(unittest.TestCase):
  """Tests for util.BatchElements and util._BatchSizeEstimator."""
  NUM_ELEMENTS = 10
  BATCH_SIZE = 5

  @staticmethod
  def _create_test_data():
    """Returns NUM_ELEMENTS names, cycling through a fixed list of names."""
    names = (
        "Einstein",
        "Darwin",
        "Copernicus",
        "Pasteur",
        "Curie",
        "Faraday",
        "Newton",
        "Bohr",
        "Galilei",
        "Maxwell",
    )
    return [
        names[position % len(names)]
        for position in range(BatchElementsTest.NUM_ELEMENTS)
    ]

  @staticmethod
  def _linear_batch_duration(batch_size):
    """Simulated batch cost: a fixed 1 plus .7 per element."""
    return 1 + .7 * batch_size

  @staticmethod
  def _drive_estimator(estimator, clock, num_batches, duration_fn):
    """Runs ``num_batches`` simulated batches through ``estimator``.

    For every batch the estimator is asked for a size, and the fake clock is
    advanced by ``duration_fn(size)`` inside ``estimator.record_time``.

    Returns:
      The list of batch sizes the estimator proposed, in order.
    """
    proposed_sizes = []
    for _ in range(num_batches):
      size = estimator.next_batch_size()
      proposed_sizes.append(size)
      with estimator.record_time(size):
        clock.sleep(duration_fn(size))
    return proposed_sizes

  def test_constant_batch(self):
    """Fixed-size batching emits full batches plus a remainder, with metrics."""
    # Assumes a single bundle...
    pipeline = TestPipeline()
    batch_lengths = (
        pipeline
        | beam.Create(range(35))
        | util.BatchElements(min_batch_size=10, max_batch_size=10)
        | beam.Map(len))
    assert_that(batch_lengths, equal_to([10, 10, 10, 5]))
    run_result = pipeline.run()
    run_result.wait_until_finish()
    query_result = run_result.metrics().query(
        MetricsFilter().with_name("batch_size"))
    self.assertEqual(len(query_result["distributions"]), 1)

  def test_constant_batch_no_metrics(self):
    """With record_metrics=False no batch_size distribution is reported."""
    pipeline = TestPipeline()
    batch_lengths = (
        pipeline
        | beam.Create(range(35))
        | util.BatchElements(
            min_batch_size=10, max_batch_size=10, record_metrics=False)
        | beam.Map(len))
    assert_that(batch_lengths, equal_to([10, 10, 10, 5]))
    run_result = pipeline.run()
    run_result.wait_until_finish()
    query_result = run_result.metrics().query(
        MetricsFilter().with_name("batch_size"))
    self.assertEqual(len(query_result["distributions"]), 0)

  def test_grows_to_max_batch(self):
    """Batch sizes double from the minimum until capped by the maximum."""
    # Assumes a single bundle, so we pin to the FnApiRunner
    with TestPipeline('FnApiRunner') as pipeline:
      batch_lengths = (
          pipeline
          | beam.Create(range(164))
          | util.BatchElements(
              min_batch_size=1, max_batch_size=50, clock=FakeClock())
          | beam.Map(len))
      assert_that(batch_lengths, equal_to([1, 1, 2, 4, 8, 16, 32, 50, 50]))

  def test_windowed_batches(self):
    """Batches never span fixed-window boundaries."""
    # Assumes a single bundle in order, so we pin to the FnApiRunner
    with TestPipeline('FnApiRunner') as pipeline:
      batch_lengths = (
          pipeline
          | beam.Create(range(47), reshuffle=False)
          | beam.Map(lambda t: window.TimestampedValue(t, t))
          | beam.WindowInto(window.FixedWindows(30))
          | util.BatchElements(
              min_batch_size=5, max_batch_size=10, clock=FakeClock())
          | beam.Map(len))
      expected_lengths = [
          5,
          5,
          10,
          10,  # elements in [0, 30)
          10,
          7,  # elements in [30, 47)
      ]
      assert_that(batch_lengths, equal_to(expected_lengths))

  def test_global_batch_timestamps(self):
    """Global-window batches carry the window's max timestamp."""
    # Assumes a single bundle, so we pin to the FnApiRunner
    with TestPipeline('FnApiRunner') as pipeline:
      sized_and_stamped = (
          pipeline
          | beam.Create(range(3), reshuffle=False)
          | util.BatchElements(min_batch_size=2, max_batch_size=2)
          | beam.Map(
              lambda batch, timestamp=beam.DoFn.TimestampParam:
              (len(batch), timestamp)))
      window_end = GlobalWindow().max_timestamp()
      assert_that(
          sized_and_stamped, equal_to([(2, window_end), (1, window_end)]))

  def test_sized_batches(self):
    """element_size_fn makes the batch limits apply to summed sizes."""
    with TestPipeline() as pipeline:
      words = [
          'a',
          'a',  # First batch.
          'aaaaaaaaaa',  # Second batch.
          'aaaaa',
          'aaaaa',  # Third batch.
          'a',
          'aaaaaaa',
          'a',
          'a'  # Fourth batch.
      ]
      total_lengths = (
          pipeline
          | beam.Create(words, reshuffle=False)
          | util.BatchElements(
              min_batch_size=10, max_batch_size=10, element_size_fn=len)
          | beam.Map(lambda batch: ''.join(batch))
          | beam.Map(len))
      assert_that(total_lengths, equal_to([2, 10, 10, 10]))

  def test_sized_windowed_batches(self):
    """Size-based batching still respects fixed-window boundaries."""
    # Assumes a single bundle, in order so we pin to the FnApiRunner
    with TestPipeline('FnApiRunner') as pipeline:
      joined_batches = (
          pipeline
          | beam.Create(range(1, 8), reshuffle=False)
          | beam.Map(lambda t: window.TimestampedValue('a' * t, t))
          | beam.WindowInto(window.FixedWindows(3))
          | util.BatchElements(
              min_batch_size=11,
              max_batch_size=11,
              element_size_fn=len,
              clock=FakeClock())
          | beam.Map(lambda batch: ''.join(batch)))
      expected_batches = [
          'a' * (1 + 2),  # Elements in [1, 3)
          'a' * (3 + 4),  # Elements in [3, 6)
          'a' * 5,
          'a' * 6,  # Elements in [6, 9)
          'a' * 7,
      ]
      assert_that(joined_batches, equal_to(expected_batches))

  def test_target_duration(self):
    """The estimator settles on the size matching the target duration."""
    clock = FakeClock()
    estimator = util._BatchSizeEstimator(
        target_batch_overhead=None, target_batch_duration_secs=10, clock=clock)
    # 14 * .7 is as close as we can get to 10 as possible.
    expected_sizes = [1, 2, 4, 8, 14, 14, 14]
    actual_sizes = self._drive_estimator(
        estimator, clock, len(expected_sizes), self._linear_batch_duration)
    self.assertEqual(expected_sizes, actual_sizes)

  def test_target_duration_including_fixed_cost(self):
    """The target duration here also counts the fixed per-batch cost."""
    clock = FakeClock()
    estimator = util._BatchSizeEstimator(
        target_batch_overhead=None,
        target_batch_duration_secs_including_fixed_cost=10,
        clock=clock)
    # With the fixed cost counted, 1 + 12 * .7 = 9.4 stays within the target
    # of 10, while a size of 13 would cost 10.1.
    expected_sizes = [1, 2, 4, 8, 12, 12, 12]
    actual_sizes = self._drive_estimator(
        estimator, clock, len(expected_sizes), self._linear_batch_duration)
    self.assertEqual(expected_sizes, actual_sizes)

  def test_target_overhead(self):
    """The estimator settles on the size matching the target overhead."""
    clock = FakeClock()
    estimator = util._BatchSizeEstimator(
        target_batch_overhead=.05, target_batch_duration_secs=None, clock=clock)
    # At 27 items, a batch takes ~20 seconds with 5% (~1 second) overhead.
    expected_sizes = [1, 2, 4, 8, 16, 27, 27, 27]
    actual_sizes = self._drive_estimator(
        estimator, clock, len(expected_sizes), self._linear_batch_duration)
    self.assertEqual(expected_sizes, actual_sizes)

  def test_variance(self):
    """With variance set, stable sizes spread around the target size."""
    clock = FakeClock()
    variance = 0.25
    estimator = util._BatchSizeEstimator(
        target_batch_overhead=.05,
        target_batch_duration_secs=None,
        variance=variance,
        clock=clock)
    expected_target = 27
    actual_sizes = self._drive_estimator(
        estimator,
        clock,
        util._BatchSizeEstimator._MAX_DATA_POINTS - 1,
        self._linear_batch_duration)
    # Check that we're testing a good range of values.
    stable_set = set(actual_sizes[-20:])
    spread = expected_target * variance
    self.assertGreater(len(stable_set), 3)
    self.assertGreater(min(stable_set), expected_target - spread)
    self.assertLess(max(stable_set), expected_target + spread)

  def test_ignore_first_n_batch_size(self):
    """The first two timings seen for each batch size are not recorded."""
    clock = FakeClock()
    estimator = util._BatchSizeEstimator(
        clock=clock, ignore_first_n_seen_per_batch_size=2)

    expected_sizes = [
        1, 1, 1, 2, 2, 2, 4, 4, 4, 8, 8, 8, 16, 16, 16, 32, 32, 32, 64, 64, 64
    ]
    actual_sizes = []
    for step in range(len(expected_sizes)):
      size = estimator.next_batch_size()
      actual_sizes.append(size)
      with estimator.record_time(size):
        # Only every third batch is fast; that is the one expected to be kept.
        clock.sleep(0.01 if step % 3 == 2 else 1)

    self.assertEqual(expected_sizes, actual_sizes)

    # Check we only record the third timing.
    expected_data_batch_sizes = [1, 2, 4, 8, 16, 32, 64]
    actual_data_batch_sizes = [point[0] for point in estimator._data]
    self.assertEqual(expected_data_batch_sizes, actual_data_batch_sizes)
    expected_data_timing = [0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]
    for index, expected_timing in enumerate(expected_data_timing):
      self.assertAlmostEqual(expected_timing, estimator._data[index][1])

  def test_ignore_next_timing(self):
    """ignore_next_timing drops exactly the next recorded timing."""
    clock = FakeClock()
    estimator = util._BatchSizeEstimator(clock=clock)
    estimator.ignore_next_timing()

    expected_sizes = [1, 1, 2, 4, 8, 16]
    actual_sizes = []
    for step in range(len(expected_sizes)):
      size = estimator.next_batch_size()
      actual_sizes.append(size)
      with estimator.record_time(size):
        # Only the first (ignored) batch is slow.
        clock.sleep(1 if step == 0 else 0.01)

    self.assertEqual(expected_sizes, actual_sizes)

    # Check the first record_time was skipped.
    expected_data_batch_sizes = [1, 2, 4, 8, 16]
    actual_data_batch_sizes = [point[0] for point in estimator._data]
    self.assertEqual(expected_data_batch_sizes, actual_data_batch_sizes)
    expected_data_timing = [0.01, 0.01, 0.01, 0.01, 0.01]
    for index, expected_timing in enumerate(expected_data_timing):
      self.assertAlmostEqual(expected_timing, estimator._data[index][1])

  def _run_regression_test(self, linear_regression_fn, test_outliers):
    """Checks a linear-regression function against synthetic data.

    Args:
      linear_regression_fn: callable taking (xs, ys) and returning an
        (intercept, slope) pair.
      test_outliers: whether to also check the handling of repeated outlier
        points.
    """
    # Exact line y = 2x + 1.
    xs = [random.random() for _ in range(10)]
    ys = [2 * x + 1 for x in xs]
    intercept, slope = linear_regression_fn(xs, ys)
    self.assertAlmostEqual(intercept, 1)
    self.assertAlmostEqual(slope, 2)

    # Line y = 7x + 5 with a small amount of noise.
    xs = [1 + random.random() for _ in range(100)]
    ys = [7 * x + 5 + 0.01 * random.random() for x in xs]
    intercept, slope = linear_regression_fn(xs, ys)
    self.assertAlmostEqual(intercept, 5, delta=0.02)
    self.assertAlmostEqual(slope, 7, delta=0.02)

    # Test repeated xs
    xs = [1 + random.random()] * 100
    ys = [7 * x + 5 + 0.01 * random.random() for x in xs]
    intercept, slope = linear_regression_fn(xs, ys)
    self.assertAlmostEqual(intercept, 0, delta=0.02)
    self.assertAlmostEqual(slope, sum(ys) / (len(ys) * xs[0]), delta=0.02)

    if not test_outliers:
      return

    xs = [1 + random.random() for _ in range(100)]
    ys = [2 * x + 1 for x in xs]
    intercept, slope = linear_regression_fn(xs, ys)
    self.assertAlmostEqual(intercept, 1)
    self.assertAlmostEqual(slope, 2)

    # An outlier or two doesn't affect the result.
    for _ in range(2):
      xs.append(10)
      ys.append(30)
      intercept, slope = linear_regression_fn(xs, ys)
      self.assertAlmostEqual(intercept, 1)
      self.assertAlmostEqual(slope, 2)

    # But enough of them, and they're no longer outliers.
    xs.extend([10] * 10)
    ys.extend([30] * 10)
    intercept, slope = linear_regression_fn(xs, ys)
    self.assertLess(intercept, 0.5)
    self.assertGreater(slope, 2.5)

  def test_no_numpy_regression(self):
    """Runs the regression checks on the pure-Python implementation."""
    regression_fn = util._BatchSizeEstimator.linear_regression_no_numpy
    self._run_regression_test(regression_fn, False)

  def test_numpy_regression(self):
    """Runs the regression checks, with outliers, on the numpy variant."""
    try:
      # pylint: disable=wrong-import-order, wrong-import-position
      import numpy as _
    except ImportError:
      self.skipTest('numpy not available')
    regression_fn = util._BatchSizeEstimator.linear_regression_numpy
    self._run_regression_test(regression_fn, True)

  def test_stateful_constant_batch(self):
    """Fixed-size batching with max_batch_duration_secs set."""
    # Assumes a single bundle, so we pin to the FnApiRunner
    pipeline = TestPipeline('FnApiRunner')
    batch_lengths = (
        pipeline
        | beam.Create(range(35))
        | util.BatchElements(
            min_batch_size=10, max_batch_size=10, max_batch_duration_secs=100)
        | beam.Map(len))
    assert_that(batch_lengths, equal_to([10, 10, 10, 5]))
    run_result = pipeline.run()
    run_result.wait_until_finish()

  def test_stateful_in_global_window(self):
    """Counts the batches produced from the test data."""
    expected_num_batches = int(
        math.ceil(
            BatchElementsTest.NUM_ELEMENTS / BatchElementsTest.BATCH_SIZE))
    with TestPipeline() as pipeline:
      batches = (
          pipeline
          | beam.Create(BatchElementsTest._create_test_data())
          | util.BatchElements(
              min_batch_size=BatchElementsTest.BATCH_SIZE,
              max_batch_size=BatchElementsTest.BATCH_SIZE,
              max_batch_duration_secs=100))
      num_batches = batches | beam.combiners.Count.Globally()
      assert_that(num_batches, equal_to([expected_num_batches]))

  def test_stateful_buffering_timer_in_fixed_window_streaming(self):
    """Streaming batches flush on size, buffering timeout and window end."""
    window_duration = 6
    max_buffering_duration_secs = 100

    start_time = timestamp.Timestamp(0)
    timestamped_elements = [
        TimestampedValue(value, start_time + offset)
        for offset, value in enumerate(BatchElementsTest._create_test_data())
    ]
    test_stream = TestStream().add_elements(timestamped_elements)
    test_stream = test_stream.advance_processing_time(150)
    test_stream = test_stream.advance_watermark_to(
        start_time + window_duration)
    test_stream = test_stream.advance_watermark_to(
        start_time + window_duration + 1)
    test_stream = test_stream.advance_watermark_to_infinity()

    with TestPipeline(options=StandardOptions(streaming=True)) as pipeline:
      # To trigger the processing time timer, use a fake clock with start time
      # being Timestamp(0).
      fake_clock = FakeClock(now=start_time)

      num_elements_per_batch = (
          pipeline | test_stream
          | "fixed window" >> WindowInto(FixedWindows(window_duration))
          | util.BatchElements(
              min_batch_size=BatchElementsTest.BATCH_SIZE,
              max_batch_size=BatchElementsTest.BATCH_SIZE,
              max_batch_duration_secs=max_buffering_duration_secs,
              clock=fake_clock)
          | "count elements in batch" >> Map(lambda x: (None, len(x)))
          | GroupByKey()
          | "global window" >> WindowInto(GlobalWindows())
          | FlatMapTuple(lambda k, vs: vs))

      expected_batch_sizes = [
          # Window duration is 6 and batch size is 5, so output batch size
          # should be 5 (flush because of batch size reached).
          5,
          # There is only one element left in the window so batch size
          # should be 1 (flush because of max buffering duration reached).
          1,
          # Collection has 10 elements, there are only 4 left, so batch size
          # should be 4 (flush because of end of window reached).
          4,
      ]
      assert_that(
          num_elements_per_batch, equal_to(expected_batch_sizes), "assert2")

  def test_stateful_buffering_timer_in_global_window_streaming(self):
    """In a global window, the buffering timeout forces a flush."""
    max_buffering_duration_secs = 42

    start_time = timestamp.Timestamp(0)
    test_stream = TestStream().advance_watermark_to(start_time)
    # Emit one element at a time, moving processing time by 5 after each.
    for offset, value in enumerate(BatchElementsTest._create_test_data()):
      one_element = [TimestampedValue(value, start_time + offset)]
      test_stream.add_elements(one_element).advance_processing_time(5)
    final_watermark = start_time + BatchElementsTest.NUM_ELEMENTS + 1
    test_stream.advance_watermark_to(
        final_watermark).advance_watermark_to_infinity()

    with TestPipeline(options=StandardOptions(streaming=True)) as pipeline:
      # Set a batch size larger than the total number of elements.
      # Since we're in a global window, we would have been waiting
      # for all the elements to arrive without the buffering time limit.
      batch_size = BatchElementsTest.NUM_ELEMENTS * 2

      # To trigger the processing time timer, use a fake clock with start time
      # being Timestamp(0). Since the fake clock never really advances during
      # the pipeline execution, meaning that the timer is always set to the same
      # value, the timer will be fired on every element after the first firing.
      fake_clock = FakeClock(now=start_time)

      num_elements_per_batch = (
          pipeline | test_stream
          | WindowInto(
              GlobalWindows(),
              trigger=Repeatedly(AfterCount(1)),
              accumulation_mode=trigger.AccumulationMode.DISCARDING)
          | util.BatchElements(
              min_batch_size=batch_size,
              max_batch_size=batch_size,
              max_batch_duration_secs=max_buffering_duration_secs,
              clock=fake_clock)
          | 'count elements in batch' >> Map(lambda x: (None, len(x)))
          | GroupByKey()
          | FlatMapTuple(lambda k, vs: vs))

      # We will flush twice when the max buffering duration is reached and when
      # the global window ends.
      assert_that(num_elements_per_batch, equal_to([9, 1]))

  def test_stateful_grows_to_max_batch(self):
    """Batch sizes grow to the maximum with max_batch_duration_secs set."""
    # Assumes a single bundle, so we pin to the FnApiRunner
    with TestPipeline('FnApiRunner') as pipeline:
      batch_lengths = (
          pipeline
          | beam.Create(range(164))
          | util.BatchElements(
              min_batch_size=1,
              max_batch_size=50,
              max_batch_duration_secs=100,
              clock=FakeClock())
          | beam.Map(len))
      assert_that(batch_lengths, equal_to([1, 1, 2, 4, 8, 16, 32, 50, 50]))


class IdentityWindowTest(unittest.TestCase):
  """Tests for windowing with ``beam.transforms.util._IdentityWindowFn``."""
  def test_window_preserved(self):
    """Elements keep their timestamp and window through _IdentityWindowFn."""
    fixed_timestamp = timestamp.Timestamp(5)
    fixed_window = window.IntervalWindow(1.0, 2.0)

    class AddWindowDoFn(beam.DoFn):
      def process(self, element):
        yield WindowedValue(element, fixed_timestamp, [fixed_window])

    pairs = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
    with TestPipeline() as pipeline:
      expected = [
          TestWindowedValue(pair, fixed_timestamp, [fixed_window])
          for pair in pairs
      ]
      windowed_input = (
          pipeline
          | 'start' >> beam.Create(pairs)
          | 'add_windows' >> beam.ParDo(AddWindowDoFn()))
      assert_that(
          windowed_input,
          equal_to(expected),
          label='before_identity',
          reify_windows=True)

      identity_window_fn = beam.transforms.util._IdentityWindowFn(
          coders.IntervalWindowCoder())
      rewindowed = windowed_input | 'window' >> beam.WindowInto(
          identity_window_fn)
      # The very same windowed values are expected on both sides.
      assert_that(
          rewindowed,
          equal_to(expected),
          label='after_identity',
          reify_windows=True)

  def test_no_window_context_fails(self):
    """Running the pipeline raises when the assign context has no window."""
    fixed_timestamp = timestamp.Timestamp(5)
    # Assuming the default window function is window.GlobalWindows.
    default_window = window.GlobalWindow()

    class AddTimestampDoFn(beam.DoFn):
      def process(self, element):
        yield window.TimestampedValue(element, fixed_timestamp)

    failure_pattern = r'.*window.*None.*add_timestamps2'
    with self.assertRaisesRegex(Exception, failure_pattern):
      with TestPipeline() as pipeline:
        pairs = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
        expected = [
            TestWindowedValue(pair, fixed_timestamp, [default_window])
            for pair in pairs
        ]
        timestamped_input = (
            pipeline
            | 'start' >> beam.Create(pairs)
            | 'add_timestamps' >> beam.ParDo(AddTimestampDoFn()))
        assert_that(
            timestamped_input,
            equal_to(expected),
            label='before_identity',
            reify_windows=True)

        identity_window_fn = beam.transforms.util._IdentityWindowFn(
            coders.GlobalWindowCoder())
        rewindowed = timestamped_input | 'window' >> beam.WindowInto(
            identity_window_fn)
        # This DoFn will return TimestampedValues, making
        # WindowFn.AssignContext passed to IdentityWindowFn
        # contain a window of None. IdentityWindowFn should
        # raise an exception.
        _ = rewindowed | 'add_timestamps2' >> beam.ParDo(AddTimestampDoFn())


class ReshuffleTest(unittest.TestCase):
  def test_reshuffle_contents_unchanged(self):
    """Reshuffle outputs the same elements it received."""
    pairs = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
    with TestPipeline() as pipeline:
      source = pipeline | beam.Create(pairs)
      reshuffled = source | beam.Reshuffle()
      assert_that(reshuffled, equal_to(pairs))

  def test_reshuffle_contents_unchanged_with_buckets(self):
    """Reshuffle given a bucket count outputs the elements it received."""
    pairs = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
    bucket_count = 2
    with TestPipeline() as pipeline:
      source = pipeline | beam.Create(pairs)
      reshuffled = source | beam.Reshuffle(bucket_count)
      assert_that(reshuffled, equal_to(pairs))

  def test_reshuffle_contents_unchanged_with_wrong_buckets(self):
    """Constructing Reshuffle with an invalid bucket count raises ValueError.

    Each invalid value runs in its own subTest so a failure names the
    offending value instead of aborting the loop at the first mismatch.
    """
    wrong_buckets = [0, -1, "wrong", 2.5]
    for wrong_bucket in wrong_buckets:
      expected_message = (
          'If `num_buckets` is set, it has to be an '
          'integer greater than 0, got %s' % wrong_bucket)
      with self.subTest(wrong_bucket=wrong_bucket):
        # assertRaisesRegex treats its argument as a regular expression, so
        # escape the message; otherwise the '.' in '2.5' matches any
        # character.
        with self.assertRaisesRegex(ValueError, re.escape(expected_message)):
          beam.Reshuffle(wrong_bucket)

  def test_reshuffle_after_gbk_contents_unchanged(self):
    """Grouped (key, sorted values) pairs are unchanged by Reshuffle."""
    pairs = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
    grouped_expected = [(1, [1, 2, 3]), (2, [1, 2]), (3, [1])]
    with TestPipeline() as pipeline:
      grouped = pipeline | beam.Create(pairs) | beam.GroupByKey()
      # Sort the grouped values so the comparison does not depend on the
      # order in which GroupByKey yields them.
      after_gbk = grouped | beam.MapTuple(lambda k, vs: (k, sorted(vs)))
      assert_that(after_gbk, equal_to(grouped_expected), label='after_gbk')

      after_reshuffle = after_gbk | beam.Reshuffle()
      assert_that(
          after_reshuffle, equal_to(grouped_expected), label='after_reshuffle')

  def test_reshuffle_timestamps_unchanged(self):
    """Elements keep their timestamp and global window through Reshuffle."""
    event_time = 5
    pairs = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
    with TestPipeline() as pipeline:
      expected = [
          TestWindowedValue(pair, event_time, [GlobalWindow()])
          for pair in pairs
      ]
      created = pipeline | 'start' >> beam.Create(pairs)
      before_reshuffle = created | 'add_timestamp' >> beam.Map(
          lambda v: beam.window.TimestampedValue(v, event_time))
      assert_that(
          before_reshuffle,
          equal_to(expected),
          label='before_reshuffle',
          reify_windows=True)

      after_reshuffle = before_reshuffle | beam.Reshuffle()
      assert_that(
          after_reshuffle,
          equal_to(expected),
          label='after_reshuffle',
          reify_windows=True)

  def test_reshuffle_windows_unchanged(self):
    """Session-windowed, grouped values are unchanged by Reshuffle."""
    pairs = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
    with TestPipeline() as pipeline:
      # (grouped value, window end, merged session window) per expected
      # output element.
      grouped_specs = [
          ((1, contains_in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
          ((2, contains_in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
          ((3, [1]), 3.0, IntervalWindow(1.0, 3.0)),
          ((1, [4]), 6.0, IntervalWindow(4.0, 6.0)),
      ]
      expected = [
          TestWindowedValue(
              value,
              window_end - .001, [session_window],
              pane_info=PaneInfo(True, False, PaneInfoTiming.ON_TIME, 0, 0))
          for (value, window_end, session_window) in grouped_specs
      ]

      timestamped = (
          pipeline
          | 'start' >> beam.Create(pairs)
          | 'add_timestamp' >>
          beam.Map(lambda v: beam.window.TimestampedValue(v, v[1])))
      before_reshuffle = (
          timestamped
          | 'window' >> beam.WindowInto(Sessions(gap_size=2))
          | 'group_by_key' >> beam.GroupByKey())
      assert_that(
          before_reshuffle,
          equal_to(expected),
          label='before_reshuffle',
          reify_windows=True)

      after_reshuffle = before_reshuffle | beam.Reshuffle()
      assert_that(
          after_reshuffle,
          equal_to(expected),
          label='after reshuffle',
          reify_windows=True)

  def test_reshuffle_window_fn_preserved(self):
    """Session windows survive Reshuffle and still merge in a later GBK."""
    in_any_order = contains_in_any_order
    pairs = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
    with TestPipeline() as pipeline:
      # (element, timestamp, unmerged session window) per input element.
      unmerged_specs = [
          ((1, 1), 1.0, IntervalWindow(1.0, 3.0)),
          ((2, 1), 1.0, IntervalWindow(1.0, 3.0)),
          ((3, 1), 1.0, IntervalWindow(1.0, 3.0)),
          ((1, 2), 2.0, IntervalWindow(2.0, 4.0)),
          ((2, 2), 2.0, IntervalWindow(2.0, 4.0)),
          ((1, 4), 4.0, IntervalWindow(4.0, 6.0)),
      ]
      expected_unmerged = [
          TestWindowedValue(value, ts, [win])
          for (value, ts, win) in unmerged_specs
      ]

      # (grouped value, window end, merged session window) per group.
      merged_specs = [
          ((1, in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
          ((2, in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
          ((3, [1]), 3.0, IntervalWindow(1.0, 3.0)),
          ((1, [4]), 6.0, IntervalWindow(4.0, 6.0)),
      ]
      expected_merged = [
          TestWindowedValue(
              value,
              window_end - .001, [win],
              pane_info=PaneInfo(True, False, PaneInfoTiming.ON_TIME, 0, 0))
          for (value, window_end, win) in merged_specs
      ]

      timestamped = (
          pipeline
          | 'start' >> beam.Create(pairs)
          | 'add_timestamp' >> beam.Map(lambda v: TimestampedValue(v, v[1])))
      before_reshuffle = timestamped | 'window' >> beam.WindowInto(
          Sessions(gap_size=2))
      assert_that(
          before_reshuffle,
          equal_to(expected_unmerged),
          label='before_reshuffle',
          reify_windows=True)

      after_reshuffle = before_reshuffle | beam.Reshuffle()
      assert_that(
          after_reshuffle,
          equal_to(expected_unmerged),
          label='after_reshuffle',
          reify_windows=True)

      # Grouping after the reshuffle is expected to yield merged sessions.
      after_group = after_reshuffle | beam.GroupByKey()
      assert_that(
          after_group,
          equal_to(expected_merged),
          label='after_group',
          reify_windows=True)

  def test_reshuffle_global_window(self):
    """Values grouped in the global window are unchanged by Reshuffle."""
    pairs = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
    grouped_expected = [(1, [1, 2, 4]), (2, [1, 2]), (3, [1])]
    with TestPipeline() as pipeline:
      grouped = (
          pipeline
          | beam.Create(pairs)
          | beam.WindowInto(GlobalWindows())
          | beam.GroupByKey())
      before_reshuffle = grouped | beam.MapTuple(
          lambda k, vs: (k, sorted(vs)))
      assert_that(
          before_reshuffle,
          equal_to(grouped_expected),
          label='before_reshuffle')

      after_reshuffle = before_reshuffle | beam.Reshuffle()
      assert_that(
          after_reshuffle, equal_to(grouped_expected), label='after reshuffle')

  def test_reshuffle_sliding_window(self):
    """Reshuffle does not re-apply the sliding window function."""
    pairs = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
    window_size = 2
    # The expected groups are repeated once per unit of window size.
    per_window_groups = [(1, [1, 2, 4]), (2, [1, 2]), (3, [1])]
    grouped_expected = per_window_groups * window_size
    with TestPipeline() as pipeline:
      grouped = (
          pipeline
          | beam.Create(pairs)
          | beam.WindowInto(SlidingWindows(size=window_size, period=1))
          | beam.GroupByKey())
      before_reshuffle = grouped | beam.MapTuple(
          lambda k, vs: (k, sorted(vs)))
      assert_that(
          before_reshuffle,
          equal_to(grouped_expected),
          label='before_reshuffle')

      after_reshuffle = before_reshuffle | beam.Reshuffle()
      # If Reshuffle applies the sliding window function a second time there
      # should be extra values for each key.
      assert_that(
          after_reshuffle, equal_to(grouped_expected), label='after reshuffle')

  def test_reshuffle_streaming_global_window(self):
    """Same as the global-window test, with the streaming option enabled."""
    pairs = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
    grouped_expected = [(1, [1, 2, 4]), (2, [1, 2]), (3, [1])]

    streaming_options = PipelineOptions()
    streaming_options.view_as(StandardOptions).streaming = True
    with TestPipeline(options=streaming_options) as pipeline:
      grouped = (
          pipeline
          | beam.Create(pairs)
          | beam.WindowInto(GlobalWindows())
          | beam.GroupByKey())
      before_reshuffle = grouped | beam.MapTuple(
          lambda k, vs: (k, sorted(vs)))
      assert_that(
          before_reshuffle,
          equal_to(grouped_expected),
          label='before_reshuffle')

      after_reshuffle = before_reshuffle | beam.Reshuffle()
      assert_that(
          after_reshuffle, equal_to(grouped_expected), label='after reshuffle')

  def test_reshuffle_streaming_global_window_with_buckets(self):
    """Streaming global-window Reshuffle with an explicit bucket count."""
    pairs = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
    grouped_expected = [(1, [1, 2, 4]), (2, [1, 2]), (3, [1])]
    bucket_count = 2

    streaming_options = PipelineOptions()
    streaming_options.view_as(StandardOptions).streaming = True
    with TestPipeline(options=streaming_options) as pipeline:
      grouped = (
          pipeline
          | beam.Create(pairs)
          | beam.WindowInto(GlobalWindows())
          | beam.GroupByKey())
      before_reshuffle = grouped | beam.MapTuple(
          lambda k, vs: (k, sorted(vs)))
      assert_that(
          before_reshuffle,
          equal_to(grouped_expected),
          label='before_reshuffle')

      after_reshuffle = before_reshuffle | beam.Reshuffle(bucket_count)
      assert_that(
          after_reshuffle, equal_to(grouped_expected), label='after reshuffle')

  @parameterized.expand([
      param(compat_version=None),
      param(compat_version="2.64.0"),
  ])
  def test_reshuffle_custom_window_preserves_metadata(self, compat_version):
    """Tests that Reshuffle preserves pane info.

    Twelve elements for one key are streamed in four groups of three through
    a repeatedly-triggered global window and a GroupByKey, then reshuffled.
    Without a compatibility version each output is expected to carry the
    EARLY pane info of its firing; with ``compat_version`` set the expected
    pane info is ``PANE_INFO_UNKNOWN``.

    Args:
      compat_version: value passed as ``update_compatibility_version`` to the
        pipeline options, or None.
    """
    from apache_beam.coders import typecoders

    element_count = 12
    pane_count = 4
    pane_size = element_count // pane_count
    timestamp_value = timestamp.Timestamp(0)
    l = [
        TimestampedValue(("key", i), timestamp_value)
        for i in range(element_count)
    ]
    # One slice of the input per expected pane, in arrival order.
    panes = [
        l[start:start + pane_size]
        for start in range(0, element_count, pane_size)
    ]

    def expected_pane_info(pane_index):
      if compat_version is not None:
        return PANE_INFO_UNKNOWN
      return PaneInfo(
          is_first=(pane_index == 0),
          is_last=False,
          timing=PaneInfoTiming.EARLY,  # 0
          index=pane_index,
          nonspeculative_index=-1)

    expected_timestamp = GlobalWindow().max_timestamp()
    expected = [
        TestWindowedValue(
            ('key', list(range(i * pane_size, (i + 1) * pane_size))),
            expected_timestamp,
            [GlobalWindow()],
            pane_info=expected_pane_info(i)) for i in range(pane_count)
    ]

    options = PipelineOptions(update_compatibility_version=compat_version)
    options.view_as(StandardOptions).streaming = True

    # The flag lives on the global coder registry, so it must be reset even
    # when the pipeline or its assertions fail; otherwise it leaks into the
    # tests that run afterwards.
    typecoders.registry.force_dill_deterministic_coders = True
    try:
      with beam.Pipeline(options=options) as p:
        stream_source = (
            TestStream()
            .advance_watermark_to(0)
            .advance_processing_time(100)
            .add_elements(panes[0])
            .advance_processing_time(100)
            .advance_watermark_to(100)
            .add_elements(panes[1])
            .advance_processing_time(100)
            .advance_watermark_to(200)
            .add_elements(panes[2])
            .advance_processing_time(100)
            .advance_watermark_to(300)
            .add_elements(panes[3])
            .advance_processing_time(100)
            .advance_watermark_to_infinity())
        grouped = (
            p | stream_source
            | "Rewindow" >> beam.WindowInto(
                beam.window.GlobalWindows(),
                trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)),
                accumulation_mode=trigger.AccumulationMode.DISCARDING)
            | beam.GroupByKey())

        after_reshuffle = (grouped | 'Reshuffle' >> beam.Reshuffle())

        assert_that(
            after_reshuffle,
            equal_to(expected),
            label='CheckMetadataPreserved',
            reify_windows=True)
    finally:
      typecoders.registry.force_dill_deterministic_coders = False

  @parameterized.expand([
      param(compat_version=None),
      param(compat_version="2.64.0"),
  ])
  def test_reshuffle_default_window_preserves_metadata(self, compat_version):
    """Tests that Reshuffle preserves timestamp, window, and pane info
    metadata.

    With ``compat_version`` set, the expected pane info is
    ``PANE_INFO_UNKNOWN`` while timestamps and windows are still expected to
    match the input.

    Args:
      compat_version: value passed as ``update_compatibility_version`` to the
        pipeline options, or None.
    """
    from apache_beam.coders import typecoders

    no_firing = PaneInfo(
        is_first=True,
        is_last=True,
        timing=PaneInfoTiming.UNKNOWN,
        index=0,
        nonspeculative_index=0)

    on_time_only = PaneInfo(
        is_first=True,
        is_last=True,
        timing=PaneInfoTiming.ON_TIME,
        index=0,
        nonspeculative_index=0)

    late_firing = PaneInfo(
        is_first=False,
        is_last=False,
        timing=PaneInfoTiming.LATE,
        index=1,
        nonspeculative_index=1)

    # Portable runners may not have the same level of precision on timestamps -
    # this gets the largest supported timestamp with the extra non-supported
    # bits truncated
    gt = GlobalWindow().max_timestamp()
    truncated_gt = gt - (gt % 0.001)

    # (value, timestamp, pane info) for every input element; the inputs and
    # both expectation lists are derived from this single spec.
    specs = [
        ('a', MIN_TIMESTAMP, no_firing),
        ('b', timestamp.Timestamp(0), on_time_only),
        ('c', timestamp.Timestamp(33), late_firing),
        ('d', truncated_gt, no_firing),
    ]

    if compat_version is None:
      expected = [
          TestWindowedValue(value, ts, [GlobalWindow()], pane)
          for (value, ts, pane) in specs
      ]
    else:
      expected = [
          TestWindowedValue(value, ts, [GlobalWindow()], PANE_INFO_UNKNOWN)
          for (value, ts, _) in specs
      ]

    options = PipelineOptions(update_compatibility_version=compat_version)

    # The flag lives on the global coder registry, so it must be reset even
    # when the pipeline or its assertions fail; otherwise it leaks into the
    # tests that run afterwards.
    typecoders.registry.force_dill_deterministic_coders = True
    try:
      with TestPipeline(options=options) as pipeline:
        # Create windowed values with specific metadata
        elements = [
            WindowedValue(value, ts, [GlobalWindow()], pane)
            for (value, ts, pane) in specs
        ]

        after_reshuffle = (
            pipeline
            | 'Create' >> beam.Create(elements)
            | 'Reshuffle' >> beam.Reshuffle())

        assert_that(
            after_reshuffle,
            equal_to(expected),
            label='CheckMetadataPreserved',
            reify_windows=True)
    finally:
      typecoders.registry.force_dill_deterministic_coders = False

  @pytest.mark.it_validatesrunner
  def test_reshuffle_preserves_timestamps(self):
    """Element timestamps read the same before and after Reshuffle."""
    records = [
        {
            'name': 'foo', 'timestamp': MIN_TIMESTAMP
        },
        {
            'name': 'foo', 'timestamp': 0
        },
        {
            'name': 'bar', 'timestamp': 33
        },
        {
            'name': 'bar', 'timestamp': 0
        },
    ]
    expected_lines = [
        'MIN_TIMESTAMP - foo',
        'Timestamp(0) - foo',
        'Timestamp(33) - bar',
        'Timestamp(0) - bar'
    ]

    # Renders "<timestamp> - <name>" for an element, spelling out the two
    # sentinel timestamps by name.
    def format_with_timestamp(element, timestamp=beam.DoFn.TimestampParam):
      if timestamp == MIN_TIMESTAMP:
        rendered = 'MIN_TIMESTAMP'
      elif timestamp == MAX_TIMESTAMP:
        rendered = 'MAX_TIMESTAMP'
      else:
        rendered = str(timestamp)
      return '{} - {}'.format(rendered, element['name'])

    with TestPipeline() as pipeline:
      # Stamp every record with the timestamp stored in its own field.
      before_reshuffle = (
          pipeline
          | beam.Create(records)
          | beam.Map(
              lambda element: beam.window.TimestampedValue(
                  element, element['timestamp'])))
      after_reshuffle = before_reshuffle | beam.Reshuffle()

      lines_before = (
          before_reshuffle
          | "Get before_reshuffle timestamp" >> beam.Map(format_with_timestamp))
      lines_after = (
          after_reshuffle
          | "Get after_reshuffle timestamp" >> beam.Map(format_with_timestamp))

      # The two PCollections are deferred and equal_to only takes a concrete
      # argument, so each side is compared against the same literal list
      # rather than against the other side.
      assert_that(
          lines_before,
          equal_to(expected_lines),
          label="formatted_before_reshuffle")
      assert_that(
          lines_after,
          equal_to(expected_lines),
          label="formatted_after_reshuffle")

  def reshuffle_unpicklable_in_global_window_helper(
      self, update_compatibility_version=None):
    """Reshuffles _Unpicklable elements in the global window.

    Args:
      update_compatibility_version: forwarded to PipelineOptions under the
        same name.
    """
    options = PipelineOptions(
        update_compatibility_version=update_compatibility_version)
    with TestPipeline(options=options) as pipeline:
      unpicklables = [_Unpicklable(i) for i in range(5)]
      scaled_expected = [0, 10, 20, 30, 40]
      reshuffled = (
          pipeline
          | beam.Create(unpicklables)
          | beam.WindowInto(GlobalWindows())
          | beam.Reshuffle())
      scaled = reshuffled | beam.Map(lambda u: u.value * 10)
      assert_that(scaled, equal_to(scaled_expected))

  def test_reshuffle_unpicklable_in_global_window(self):
    """Unpicklable elements reshuffle now but fail under version 2.63.0."""
    beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)

    run_helper = self.reshuffle_unpicklable_in_global_window_helper
    run_helper()
    # An exception is raised when running reshuffle on unpicklable objects
    # prior to 2.64.0
    with self.assertRaises(RuntimeError):
      run_helper("2.63.0")

  def reshuffle_unpicklable_in_non_global_window_helper(
      self, update_compatibility_version=None):
    """Reshuffles _Unpicklable elements windowed into sliding windows.

    Args:
      update_compatibility_version: forwarded to PipelineOptions under the
        same name.
    """
    options = PipelineOptions(
        update_compatibility_version=update_compatibility_version)
    with TestPipeline(options=options) as pipeline:
      unpicklables = [_Unpicklable(i) for i in range(5)]
      # Every scaled value is expected three times (window size 3, period 1).
      scaled_expected = [
          0, 0, 0, 10, 10, 10, 20, 20, 20, 30, 30, 30, 40, 40, 40
      ]
      reshuffled = (
          pipeline
          | beam.Create(unpicklables)
          | beam.WindowInto(window.SlidingWindows(size=3, period=1))
          | beam.Reshuffle())
      scaled = reshuffled | beam.Map(lambda u: u.value * 10)
      assert_that(scaled, equal_to(scaled_expected))

  def test_reshuffle_unpicklable_in_non_global_window(self):
    """Sliding-window variant of the unpicklable reshuffle test."""
    beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)

    run_helper = self.reshuffle_unpicklable_in_non_global_window_helper
    run_helper()
    # An exception is raised when running reshuffle on unpicklable objects
    # prior to 2.64.0
    with self.assertRaises(RuntimeError):
      run_helper("2.63.0")


class WithKeysTest(unittest.TestCase):
  def setUp(self):
    self.l = [1, 2, 3]

  def test_constant_k(self):
    with TestPipeline() as p:
      pc = p | beam.Create(self.l)
      with_keys = pc | util.WithKeys('k')
      assert_that(with_keys, equal_to([('k', 1), ('k', 2), ('k', 3)],
                                      ))

  def test_callable_k(self):
    with TestPipeline() as p:
      pc = p | beam.Create(self.l)
      with_keys = pc | util.WithKeys(lambda x: x * x)
      assert_that(with_keys, equal_to([(1, 1), (4, 2), (9, 3)]))

  @staticmethod
  def _test_args_kwargs_fn(x, multiply, subtract):
    return x * multiply - subtract

  def test_args_kwargs_k(self):
    with TestPipeline() as p:
      pc = p | beam.Create(self.l)
      with_keys = pc | util.WithKeys(
          WithKeysTest._test_args_kwargs_fn, 2, subtract=1)
      assert_that(with_keys, equal_to([(1, 1), (3, 2), (5, 3)]))

  def test_sideinputs(self):
    with TestPipeline() as p:
      pc = p | beam.Create(self.l)
      si1 = AsList(p | "side input 1" >> beam.Create([1, 2, 3]))
      si2 = AsSingleton(p | "side input 2" >> beam.Create([10]))
      with_keys = pc | util.WithKeys(
          lambda x, the_list, the_singleton: x + sum(the_list) + the_singleton,
          si1,
          the_singleton=si2)
      assert_that(with_keys, equal_to([(17, 1), (18, 2), (19, 3)]))


class GroupIntoBatchesTest(unittest.TestCase):
  NUM_ELEMENTS = 10
  BATCH_SIZE = 5

  @staticmethod
  def _create_test_data():
    scientists = [
        "Einstein",
        "Darwin",
        "Copernicus",
        "Pasteur",
        "Curie",
        "Faraday",
        "Newton",
        "Bohr",
        "Galilei",
        "Maxwell"
    ]

    data = []
    for i in range(GroupIntoBatchesTest.NUM_ELEMENTS):
      index = i % len(scientists)
      data.append(("key", scientists[index]))
    return data

  def test_in_global_window(self):
    """The batch count is ceil(NUM_ELEMENTS / BATCH_SIZE)."""
    expected_batch_count = int(
        math.ceil(
            GroupIntoBatchesTest.NUM_ELEMENTS /
            GroupIntoBatchesTest.BATCH_SIZE))
    with TestPipeline() as pipeline:
      batches = (
          pipeline
          | beam.Create(GroupIntoBatchesTest._create_test_data())
          | util.GroupIntoBatches(GroupIntoBatchesTest.BATCH_SIZE))
      batch_count = batches | beam.combiners.Count.Globally()
      assert_that(batch_count, equal_to([expected_batch_count]))

  def test_in_global_window_with_synthetic_source(self):
    """Ten synthetic records under one key form five batches of two."""
    source_spec = {"numRecords": 10, "keySizeBytes": 1, "valueSizeBytes": 1}
    with beam.Pipeline() as pipeline:
      records = pipeline | beam.io.Read(SyntheticSource(source_spec))
      # Replace every key with None so all records share a single key.
      same_key = records | "identical keys" >> beam.Map(lambda x: (None, x[1]))
      batch_sizes = (
          same_key
          | "Group key" >> beam.GroupIntoBatches(2)
          | "count size" >> beam.Map(lambda x: len(x[1])))
      assert_that(batch_sizes, equal_to([2, 2, 2, 2, 2]))

  def test_with_sharded_key_in_global_window(self):
    """WithShardedKey yields ceil(NUM_ELEMENTS / BATCH_SIZE) batches."""
    expected_batch_count = int(
        math.ceil(
            GroupIntoBatchesTest.NUM_ELEMENTS /
            GroupIntoBatchesTest.BATCH_SIZE))
    with TestPipeline() as pipeline:
      source = pipeline | beam.Create(GroupIntoBatchesTest._create_test_data())
      batches = source | util.GroupIntoBatches.WithShardedKey(
          GroupIntoBatchesTest.BATCH_SIZE)
      batch_count = batches | beam.combiners.Count.Globally()
      assert_that(batch_count, equal_to([expected_batch_count]))

  def test_buffering_timer_in_fixed_window_streaming(self):
    """Batches flush on size, on buffering timeout, and at end of window."""
    window_duration = 6
    max_buffering_duration_secs = 100
    start_time = timestamp.Timestamp(0)

    # One element per second starting at start_time.
    timed_values = [
        TimestampedValue(value, start_time + offset) for offset,
        value in enumerate(GroupIntoBatchesTest._create_test_data())
    ]
    test_stream = (
        TestStream()
        .add_elements(timed_values)
        .advance_processing_time(150)
        .advance_watermark_to(start_time + window_duration)
        .advance_watermark_to(start_time + window_duration + 1)
        .advance_watermark_to_infinity())

    with TestPipeline(options=StandardOptions(streaming=True)) as pipeline:
      # To trigger the processing time timer, use a fake clock with start time
      # being Timestamp(0).
      fake_clock = FakeClock(now=start_time)

      windowed = (
          pipeline | test_stream
          | "fixed window" >> WindowInto(FixedWindows(window_duration)))
      batches = windowed | util.GroupIntoBatches(
          GroupIntoBatchesTest.BATCH_SIZE,
          max_buffering_duration_secs,
          fake_clock)
      num_elements_per_batch = (
          batches
          | "count elements in batch" >> Map(lambda x: (None, len(x[1])))
          | GroupByKey()
          | "global window" >> WindowInto(GlobalWindows())
          | FlatMapTuple(lambda k, vs: vs))

      expected_batch_sizes = [
          # Window duration is 6 and batch size is 5, so output batch size
          # should be 5 (flush because of batch size reached).
          5,
          # There is only one element left in the window so batch size
          # should be 1 (flush because of max buffering duration reached).
          1,
          # Collection has 10 elements, there are only 4 left, so batch size
          # should be 4 (flush because of end of window reached).
          4,
      ]
      assert_that(
          num_elements_per_batch, equal_to(expected_batch_sizes), "assert2")

  def test_buffering_timer_in_global_window_streaming(self):
    """In the global window the buffering timer flushes a partial batch."""
    max_buffering_duration_secs = 42
    start_time = timestamp.Timestamp(0)

    test_stream = TestStream().advance_watermark_to(start_time)
    # Emit one element at a time, moving processing time forward by 5 after
    # each one.
    test_data = GroupIntoBatchesTest._create_test_data()
    for offset, value in enumerate(test_data):
      (
          test_stream
          .add_elements([TimestampedValue(value, start_time + offset)])
          .advance_processing_time(5))
    final_watermark = start_time + GroupIntoBatchesTest.NUM_ELEMENTS + 1
    (
        test_stream
        .advance_watermark_to(final_watermark)
        .advance_watermark_to_infinity())

    with TestPipeline(options=StandardOptions(streaming=True)) as pipeline:
      # Set a batch size larger than the total number of elements.
      # Since we're in a global window, we would have been waiting
      # for all the elements to arrive without the buffering time limit.
      batch_size = GroupIntoBatchesTest.NUM_ELEMENTS * 2

      # To trigger the processing time timer, use a fake clock with start time
      # being Timestamp(0). Since the fake clock never really advances during
      # the pipeline execution, meaning that the timer is always set to the same
      # value, the timer will be fired on every element after the first firing.
      fake_clock = FakeClock(now=start_time)

      windowed = (
          pipeline | test_stream
          | WindowInto(
              GlobalWindows(),
              trigger=Repeatedly(AfterCount(1)),
              accumulation_mode=trigger.AccumulationMode.DISCARDING))
      batches = windowed | util.GroupIntoBatches(
          batch_size, max_buffering_duration_secs, fake_clock)
      num_elements_per_batch = (
          batches
          | 'count elements in batch' >> Map(lambda x: (None, len(x[1])))
          | GroupByKey()
          | FlatMapTuple(lambda k, vs: vs))

      # We will flush twice when the max buffering duration is reached and when
      # the global window ends.
      assert_that(num_elements_per_batch, equal_to([9, 1]))

  def test_output_typehints(self):
    """Checks the output type of GroupIntoBatches.WithShardedKey.

    First checks the type returned by ``infer_output_type`` structurally,
    then checks the element type of the PCollection produced in a pipeline.
    """
    transform = util.GroupIntoBatches.WithShardedKey(
        GroupIntoBatchesTest.BATCH_SIZE)
    unused_input_type = typehints.Tuple[str, str]
    output_type = transform.infer_output_type(unused_input_type)
    # assertIsInstance reports the offending object when the check fails.
    self.assertIsInstance(output_type, typehints.TupleConstraint)
    k, v = output_type.tuple_types
    self.assertIsInstance(k, ShardedKeyType)
    self.assertIsInstance(v, typehints.IterableTypeConstraint)

    with TestPipeline() as pipeline:
      collection = (
          pipeline
          | beam.Create([((1, 2), 'a'), ((2, 3), 'b')])
          | util.GroupIntoBatches.WithShardedKey(
              GroupIntoBatchesTest.BATCH_SIZE))
      # This was assertTrue(actual, expected), which treats the expected type
      # as the failure message and passes for any truthy element type.
      # assertEqual performs the comparison that was intended.
      self.assertEqual(
          collection.element_type,
          typehints.Tuple[
              ShardedKeyType[typehints.Tuple[int, int]],  # type: ignore[misc]
              typehints.Iterable[str]])

  def test_runtime_type_check(self):
    """GroupIntoBatches yields the expected batch count with runtime checks."""
    expected_batch_count = int(
        math.ceil(
            GroupIntoBatchesTest.NUM_ELEMENTS /
            GroupIntoBatchesTest.BATCH_SIZE))

    checked_options = PipelineOptions()
    checked_options.view_as(TypeOptions).runtime_type_check = True
    with TestPipeline(options=checked_options) as pipeline:
      source = pipeline | beam.Create(GroupIntoBatchesTest._create_test_data())
      batches = source | util.GroupIntoBatches(GroupIntoBatchesTest.BATCH_SIZE)
      batch_count = batches | beam.combiners.Count.Globally()
      assert_that(batch_count, equal_to([expected_batch_count]))

  def _test_runner_api_round_trip(self, transform, urn):
    """Serializes `transform` to its runner API proto and back again.

    Args:
      transform: the transform to serialize; its ``params`` must expose
        ``batch_size`` and ``max_buffering_duration_secs``.
      urn: the URN the serialized proto is expected to carry.
    """
    proto = transform.to_runner_api(pipeline_context.PipelineContext())
    self.assertEqual(urn, proto.urn)

    payload = proto_utils.parse_Bytes(
        proto.payload, beam_runner_api_pb2.GroupIntoBatchesPayload)
    expected_params = transform.params
    self.assertEqual(expected_params.batch_size, payload.batch_size)
    # The transform stores seconds, the payload stores milliseconds.
    expected_millis = expected_params.max_buffering_duration_secs * 1000
    self.assertEqual(expected_millis, payload.max_buffering_duration_millis)

    transform_cls = transform.__class__
    restored = transform_cls.from_runner_api_parameter(None, payload, None)
    self.assertIsInstance(restored, transform_cls)
    self.assertEqual(transform.params, restored.params)

  def test_runner_api(self):
    """Round-trips both GroupIntoBatches variants through the runner API.

    Each variant is exercised with every listed buffering duration and once
    more with the duration argument omitted.
    """
    batch_size = 10
    durations = (None, 0, 5)
    variants = (
        (
            util.GroupIntoBatches,
            common_urns.group_into_batches_components.GROUP_INTO_BATCHES.urn),
        (
            util.GroupIntoBatches.WithShardedKey,
            common_urns.composites.GROUP_INTO_BATCHES_WITH_SHARDED_KEY.urn),
    )

    for make_transform, expected_urn in variants:
      for duration in durations:
        self._test_runner_api_round_trip(
            make_transform(batch_size, duration), expected_urn)
      self._test_runner_api_round_trip(
          make_transform(batch_size), expected_urn)


class ToStringTest(unittest.TestCase):
  """Tests for the ``util.ToString`` element, iterable and KV transforms."""
  def test_tostring_elements(self):
    numbers = [1, 1, 2, 3]
    with TestPipeline() as pipeline:
      as_text = pipeline | beam.Create(numbers) | util.ToString.Element()
      assert_that(as_text, equal_to(["1", "1", "2", "3"]))

  def test_tostring_iterables(self):
    rows = [("one", "two", "three"), ("four", "five", "six")]
    with TestPipeline() as pipeline:
      # No delimiter argument: the expected output is comma-joined.
      joined = pipeline | beam.Create(rows) | util.ToString.Iterables()
      assert_that(joined, equal_to(["one,two,three", "four,five,six"]))

  def test_tostring_iterables_with_delimeter(self):
    rows = [("one", "two", "three"), ("four", "five", "six")]
    with TestPipeline() as pipeline:
      joined = pipeline | beam.Create(rows) | util.ToString.Iterables("\t")
      assert_that(joined, equal_to(["one\ttwo\tthree", "four\tfive\tsix"]))

  def test_tostring_kvs(self):
    pairs = [("one", 1), ("two", 2)]
    with TestPipeline() as pipeline:
      # No delimiter argument: key and value are expected comma-separated.
      joined = pipeline | beam.Create(pairs) | util.ToString.Kvs()
      assert_that(joined, equal_to(["one,1", "two,2"]))

  def test_tostring_kvs_delimeter(self):
    pairs = [("one", 1), ("two", 2)]
    with TestPipeline() as pipeline:
      joined = pipeline | beam.Create(pairs) | util.ToString.Kvs("\t")
      assert_that(joined, equal_to(["one\t1", "two\t2"]))

  def test_tostring_kvs_empty_delimeter(self):
    pairs = [("one", 1), ("two", 2)]
    with TestPipeline() as pipeline:
      # An empty delimiter concatenates key and value directly.
      joined = pipeline | beam.Create(pairs) | util.ToString.Kvs("")
      assert_that(joined, equal_to(["one1", "two2"]))


class LogElementsTest(unittest.TestCase):
  """Tests for ``LogElements``: stdout format, pass-through and log levels.

  The pytest fixtures below are written as methods, so their first
  parameter (named ``request``) receives the test case instance. Whatever a
  fixture stores on it is read back by the test methods through ``self``.
  """
  @staticmethod
  def _two_daily_events():
    """Returns two "event" values stamped at 2022-10-01 and 2022-10-02 UTC."""
    return [
        TimestampedValue(
            "event",
            datetime(2022, 10, day, 0, 0, 0, 0, tzinfo=pytz.UTC).timestamp())
        for day in (1, 2)
    ]

  @pytest.fixture(scope="function")
  def _capture_stdout_log(request, capsys):
    """Runs LogElements with window, timestamp and pane info; keeps stdout."""
    log_everything = util.LogElements(
        prefix='prefix_',
        with_window=True,
        with_timestamp=True,
        with_pane_info=True)
    with TestPipeline() as pipeline:
      logged = (
          pipeline
          | beam.Create(LogElementsTest._two_daily_events())
          | beam.WindowInto(FixedWindows(60))
          | log_everything)

    # Read stdout only after the pipeline context has exited.
    request.captured_stdout = capsys.readouterr().out
    return logged

  @pytest.mark.usefixtures("_capture_stdout_log")
  def test_stdout_logs(self):
    expected_stdout = (
        "prefix_event, timestamp='2022-10-01T00:00:00Z', "
        "window(start=2022-10-01T00:00:00Z, end=2022-10-01T00:01:00Z), "
        "pane_info=PaneInfo(first: True, last: True, timing: UNKNOWN, "
        "index: 0, nonspeculative_index: 0)\n"
        "prefix_event, timestamp='2022-10-02T00:00:00Z', "
        "window(start=2022-10-02T00:00:00Z, end=2022-10-02T00:01:00Z), "
        "pane_info=PaneInfo(first: True, last: True, timing: UNKNOWN, "
        "index: 0, nonspeculative_index: 0)\n")
    assert self.captured_stdout == expected_stdout, \
      f'Received from stdout: {self.captured_stdout}'

  @pytest.fixture(scope="function")
  def _capture_stdout_log_without_rfc3339(request, capsys):
    """Same pipeline as above but with ``use_epoch_time=True``; keeps stdout."""
    log_epoch = util.LogElements(
        prefix='prefix_',
        with_window=True,
        with_timestamp=True,
        use_epoch_time=True)
    with TestPipeline() as pipeline:
      logged = (
          pipeline
          | beam.Create(LogElementsTest._two_daily_events())
          | beam.WindowInto(FixedWindows(60))
          | log_epoch)

    request.captured_stdout = capsys.readouterr().out
    return logged

  @pytest.mark.usefixtures("_capture_stdout_log_without_rfc3339")
  def test_stdout_logs_without_rfc3339(self):
    expected_stdout = (
        "prefix_event, timestamp=1664582400, "
        "window(start=1664582400, end=1664582460)\n"
        "prefix_event, timestamp=1664668800, "
        "window(start=1664668800, end=1664668860)\n")
    assert self.captured_stdout == expected_stdout, \
      f'Received from stdout: {self.captured_stdout}'

  def test_ptransform_output(self):
    letters = ['a', 'b', 'c']
    with TestPipeline() as pipeline:
      # LogElements is expected to hand its input through unchanged.
      passed_through = (
          pipeline
          | beam.Create(letters)
          | util.LogElements(prefix='prefix_'))
      assert_that(passed_through, equal_to(['a', 'b', 'c']))

  @pytest.fixture(scope="function")
  def _capture_logs(request, caplog):
    """Logs one element at INFO, WARNING and ERROR; keeps the log text."""
    branches = (
        ("info", "I", 'info_', logging.INFO),
        ("warning", "W", 'warning_', logging.WARNING),
        ("error", "E", 'error_', logging.ERROR),
    )
    with caplog.at_level(logging.INFO), TestPipeline() as pipeline:
      for create_label, log_label, prefix, level in branches:
        _ = (
            pipeline | create_label >> beam.Create(["element"])
            | log_label >> beam.LogElements(prefix=prefix, level=level))

    request.captured_log = caplog.text

  @pytest.mark.usefixtures("_capture_logs")
  def test_setting_level_uses_appropriate_log_channel(self):
    # Each prefixed element must appear in a record of the matching level.
    self.assertRegex(self.captured_log, 'INFO(.*)info_element')
    self.assertRegex(self.captured_log, 'WARNING(.*)warning_element')
    self.assertRegex(self.captured_log, 'ERROR(.*)error_element')


class ReifyTest(unittest.TestCase):
  """Tests for the ``util.Reify`` transforms.

  Every pipeline here places an identity ``Map`` right after ``Create``.
  ``Create`` alone leaves the TimestampedValue / WindowedValue wrappers as
  plain elements; sending them through a DoFn is what actually assigns the
  timestamp (and window) to each element.
  """
  def test_timestamp(self):
    stamps = (('a', 100), ('b', 200), ('c', 300))
    timestamped = [TimestampedValue(value, ts) for value, ts in stamps]
    expected = [
        TestWindowedValue(value, ts, [GlobalWindow()]) for value, ts in stamps
    ]
    with TestPipeline() as pipeline:
      # Identity Map: turns the TimestampedValue wrappers into elements that
      # really carry those timestamps.
      stamped = pipeline | beam.Create(timestamped) | beam.Map(lambda x: x)
      reified = stamped | util.Reify.Timestamp()
      assert_that(reified, equal_to(expected), reify_windows=True)

  def test_window(self):
    stamps = (('a', 100), ('b', 200), ('c', 300))
    windowed = [
        GlobalWindows.windowed_value(value, ts) for value, ts in stamps
    ]
    expected = [
        TestWindowedValue((value, ts, GlobalWindow()), ts, [GlobalWindow()])
        for value, ts in stamps
    ]
    with TestPipeline() as pipeline:
      created = pipeline | beam.Create(windowed)
      # Identity Map: turns the WindowedValue wrappers into elements that
      # really carry those timestamps and windows.
      assigned = created | beam.Map(lambda x: x)
      reified = assigned | util.Reify.Window()
      assert_that(reified, equal_to(expected), reify_windows=True)

  def test_timestamp_in_value(self):
    keyed = (('a', 1, 100), ('b', 2, 200), ('c', 3, 300))
    timestamped = [
        TimestampedValue((key, value), ts) for key, value, ts in keyed
    ]
    expected = [
        TestWindowedValue((key, TimestampedValue(value, ts)),
                          ts, [GlobalWindow()]) for key, value, ts in keyed
    ]
    with TestPipeline() as pipeline:
      stamped = pipeline | beam.Create(timestamped) | beam.Map(lambda x: x)
      reified = stamped | util.Reify.TimestampInValue()
      assert_that(reified, equal_to(expected), reify_windows=True)

  def test_window_in_value(self):
    keyed = (('a', 1, 100), ('b', 2, 200), ('c', 3, 300))
    windowed = [
        GlobalWindows.windowed_value((key, value), ts)
        for key, value, ts in keyed
    ]
    expected = [
        TestWindowedValue((key, (value, ts, GlobalWindow())),
                          ts, [GlobalWindow()]) for key, value, ts in keyed
    ]
    with TestPipeline() as pipeline:
      # The identity Map serves the same purpose as in the tests above, and
      # additionally makes the typehint on Reify.WindowInValue work.
      assigned = pipeline | beam.Create(windowed) | beam.Map(lambda x: x)
      reified = assigned | util.Reify.WindowInValue()
      assert_that(reified, equal_to(expected), reify_windows=True)


class RegexTest(unittest.TestCase):
  """Tests for the ``util.Regex`` transforms.

  Most tests follow one shape (create inputs, apply a single Regex
  transform, compare the output) and go through ``_assert_regex``. Tests
  that attach several labelled assertions to one pipeline are spelled out.
  """

  # Sentence with irregular spacing, shared by the split tests.
  _SENTENCE = "The  quick   brown fox jumps over    the lazy dog"
  _WORDS = (
      "The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog")

  def _assert_regex(self, inputs, regex_transform, expected):
    """Applies ``regex_transform`` to ``inputs`` and checks the output.

    Args:
      inputs: elements handed to ``beam.Create``.
      regex_transform: the transform under test.
      expected: the expected elements, passed to ``equal_to`` as given.
    """
    with TestPipeline() as pipeline:
      result = pipeline | beam.Create(inputs) | regex_transform
      assert_that(result, equal_to(expected))

  def test_find(self):
    self._assert_regex(["aj", "xj", "yj", "zj"],
                       util.Regex.find("[xyz]"), ["x", "y", "z"])

  def test_find_pattern(self):
    compiled = re.compile("[xyz]")
    self._assert_regex(["aj", "xj", "yj", "zj"],
                       util.Regex.find(compiled), ["x", "y", "z"])

  def test_find_group(self):
    self._assert_regex(["aj", "xj", "yj", "zj"],
                       util.Regex.find("([xyz])j", group=1), ["x", "y", "z"])

  def test_find_empty(self):
    self._assert_regex(["a", "b", "c", "d"], util.Regex.find("[xyz]"), [])

  def test_find_group_name(self):
    self._assert_regex(
        ["aj", "xj", "yj", "zj"],
        util.Regex.find("(?P<namedgroup>[xyz])j", group="namedgroup"),
        ["x", "y", "z"])

  def test_find_group_name_pattern(self):
    compiled = re.compile("(?P<namedgroup>[xyz])j")
    self._assert_regex(["aj", "xj", "yj", "zj"],
                       util.Regex.find(compiled, group="namedgroup"),
                       ["x", "y", "z"])

  def test_find_all_groups(self):
    lines = ["abb ax abbb", "abc qwerty abcabcd xyz"]
    with TestPipeline() as pipeline:
      source = pipeline | beam.Create(lines)

      # No group argument: the expected output is the whole of each match.
      whole_matches = (
          source | 'with default values' >> util.Regex.find_all('a(b*)'))
      assert_that(
          whole_matches,
          equal_to([['abb', 'a', 'abbb'], ['ab', 'ab', 'ab']]),
          label='CheckWithDefaultValues')

      group_one = source | 'group 1' >> util.Regex.find_all('a(b*)', 1)
      assert_that(
          group_one,
          equal_to([['b', 'b', 'b'], ['bb', '', 'bbb']]),
          label='CheckWithGroup1')

      # With outputEmpty=False the empty capture is left out.
      group_one_non_empty = (
          source | 'group 1 non empty' >> util.Regex.find_all(
              'a(b*)', 1, outputEmpty=False))
      assert_that(
          group_one_non_empty,
          equal_to([['b', 'b', 'b'], ['bb', 'bbb']]),
          label='CheckGroup1NonEmpty')

      named_group = (
          source | 'named group' >> util.Regex.find_all(
              'a(?P<namedgroup>b*)', 'namedgroup'))
      assert_that(
          named_group,
          equal_to([['b', 'b', 'b'], ['bb', '', 'bbb']]),
          label='CheckNamedGroup')

      # Regex.ALL: each match is expected as a (whole match, group) tuple.
      all_groups = (
          source | 'all groups' >> util.Regex.find_all(
              'a(?P<namedgroup>b*)', util.Regex.ALL))
      assert_that(
          all_groups,
          equal_to([[('ab', 'b'), ('ab', 'b'), ('ab', 'b')],
                    [('abb', 'bb'), ('a', ''), ('abbb', 'bbb')]]),
          label='CheckAllGroups')

      all_groups_non_empty = (
          source | 'all non empty groups' >> util.Regex.find_all(
              'a(b*)', util.Regex.ALL, outputEmpty=False))
      assert_that(
          all_groups_non_empty,
          equal_to([[('ab', 'b'), ('ab', 'b'), ('ab', 'b')],
                    [('abb', 'bb'), ('abbb', 'bbb')]]),
          label='CheckAllNonEmptyGroups')

  def test_find_kv(self):
    with TestPipeline() as pipeline:
      source = pipeline | beam.Create(['a b c d'])

      # Only the key group is given: the value is the whole match.
      key_only = source | 'key 1' >> util.Regex.find_kv('a (b) (c)', 1)
      assert_that(key_only, equal_to([('b', 'a b c')]), label='CheckKey1')

      key_and_value = (
          source | 'key 1 group 1' >> util.Regex.find_kv('a (b) (c)', 1, 2))
      assert_that(
          key_and_value, equal_to([('b', 'c')]), label='CheckKey1Group1')

  def test_find_kv_pattern(self):
    compiled = re.compile("a (b) (c)")
    self._assert_regex(["a b c"],
                       util.Regex.find_kv(compiled, 1, 2), [("b", "c")])

  def test_find_kv_none(self):
    self._assert_regex(["x y z"], util.Regex.find_kv("a (b) (c)", 1, 2), [])

  def test_match(self):
    self._assert_regex(["a", "x", "y", "z"],
                       util.Regex.matches("[xyz]"), ["x", "y", "z"])

    # Only elements that start with a matching character yield output.
    self._assert_regex(["a", "ax", "yby", "zzc"],
                       util.Regex.matches("[xyz]"), ["y", "z"])

  def test_match_entire_line(self):
    self._assert_regex(["a", "x", "y", "ay", "zz"],
                       util.Regex.matches("[xyz]$"), ["x", "y"])

  def test_match_pattern(self):
    compiled = re.compile("[xyz]")
    self._assert_regex(["a", "x", "y", "z"],
                       util.Regex.matches(compiled), ["x", "y", "z"])

  def test_match_none(self):
    self._assert_regex(["a", "b", "c", "d"], util.Regex.matches("[xyz]"), [])

  def test_match_group(self):
    self._assert_regex(["a", "x xxx", "x yyy", "x zzz"],
                       util.Regex.matches("x ([xyz]*)", 1),
                       ("xxx", "yyy", "zzz"))

  def test_match_group_name(self):
    self._assert_regex(
        ["a", "x xxx", "x yyy", "x zzz"],
        util.Regex.matches("x (?P<namedgroup>[xyz]*)", 'namedgroup'),
        ("xxx", "yyy", "zzz"))

  def test_match_group_name_pattern(self):
    compiled = re.compile("x (?P<namedgroup>[xyz]*)")
    self._assert_regex(["a", "x xxx", "x yyy", "x zzz"],
                       util.Regex.matches(compiled, 'namedgroup'),
                       ("xxx", "yyy", "zzz"))

  def test_match_group_empty(self):
    self._assert_regex(
        ["a", "b", "c", "d"],
        util.Regex.matches("x (?P<namedgroup>[xyz]*)", 'namedgroup'), [])

  def test_all_matched(self):
    expected_groups = [["x x", "x", "x"], ["y y", "y", "y"],
                       ["z z", "z", "z"]]
    self._assert_regex(["a x", "x x", "y y", "z z"],
                       util.Regex.all_matches("([xyz]) ([xyz])"),
                       expected_groups)

  def test_all_matched_pattern(self):
    compiled = re.compile("([xyz]) ([xyz])")
    expected_groups = [["x x", "x", "x"], ["y y", "y", "y"],
                       ["z z", "z", "z"]]
    self._assert_regex(["a x", "x x", "y y", "z z"],
                       util.Regex.all_matches(compiled), expected_groups)

  def test_match_group_kv(self):
    self._assert_regex(["a b c"],
                       util.Regex.matches_kv("a (b) (c)", 1, 2),
                       [("b", "c")])

  def test_match_group_kv_pattern(self):
    compiled = re.compile("a (b) (c)")
    with TestPipeline() as pipeline:
      source = pipeline | beam.Create(["a b c"])

      # Only the key group is given: the value is the whole match.
      key_only = source | 'key 1' >> util.Regex.matches_kv(compiled, 1)
      assert_that(key_only, equal_to([("b", "a b c")]), label="CheckKey1")

      key_and_value = (
          source | 'key 1 group 2' >> util.Regex.matches_kv(compiled, 1, 2))
      assert_that(
          key_and_value, equal_to([("b", "c")]), label="CheckKey1Group2")

  def test_match_group_kv_none(self):
    self._assert_regex(["x y z"],
                       util.Regex.matches_kv("a (b) (c)", 1, 2), [])

  def test_match_kv_group_names(self):
    self._assert_regex(
        ["a b c"],
        util.Regex.matches_kv(
            "a (?P<keyname>b) (?P<valuename>c)", 'keyname', 'valuename'),
        [("b", "c")])

  def test_match_kv_group_names_pattern(self):
    compiled = re.compile("a (?P<keyname>b) (?P<valuename>c)")
    self._assert_regex(["a b c"],
                       util.Regex.matches_kv(compiled, 'keyname', 'valuename'),
                       [("b", "c")])

  def test_match_kv_group_name_none(self):
    self._assert_regex(
        ["x y z"],
        util.Regex.matches_kv(
            "a (?P<keyname>b) (?P<valuename>c)", 'keyname', 'valuename'), [])

  def test_replace_all(self):
    self._assert_regex(["xj", "yj", "zj"],
                       util.Regex.replace_all("[xyz]", "new"),
                       ["newj", "newj", "newj"])

  def test_replace_all_mixed(self):
    self._assert_regex(["abc", "xj", "yj", "zj", "def"],
                       util.Regex.replace_all("[xyz]", 'new'),
                       ["abc", "newj", "newj", "newj", "def"])

  def test_replace_all_mixed_pattern(self):
    compiled = re.compile("[xyz]")
    self._assert_regex(["abc", "xj", "yj", "zj", "def"],
                       util.Regex.replace_all(compiled, 'new'),
                       ["abc", "newj", "newj", "newj", "def"])

  def test_replace_first(self):
    self._assert_regex(["xjx", "yjy", "zjz"],
                       util.Regex.replace_first("[xyz]", 'new'),
                       ["newjx", "newjy", "newjz"])

  def test_replace_first_mixed(self):
    self._assert_regex(["abc", "xjx", "yjy", "zjz", "def"],
                       util.Regex.replace_first("[xyz]", 'new'),
                       ["abc", "newjx", "newjy", "newjz", "def"])

  def test_replace_first_mixed_pattern(self):
    compiled = re.compile("[xyz]")
    self._assert_regex(["abc", "xjx", "yjy", "zjz", "def"],
                       util.Regex.replace_first(compiled, 'new'),
                       ["abc", "newjx", "newjy", "newjz", "def"])

  def test_split(self):
    self._assert_regex([self._SENTENCE],
                       util.Regex.split("\\W+"), [list(self._WORDS)])

  def test_split_pattern(self):
    compiled = re.compile("\\W+")
    self._assert_regex([self._SENTENCE],
                       util.Regex.split(compiled), [list(self._WORDS)])

  def test_split_with_empty(self):
    # Splitting on single whitespace characters with empties kept: each
    # extra space in a run of spaces shows up as an empty string.
    tokens_with_gaps = [
        'The', '', 'quick', '', '', 'brown', 'fox', 'jumps', 'over', '', '',
        '', 'the', 'lazy', 'dog'
    ]
    self._assert_regex([self._SENTENCE],
                       util.Regex.split("\\s", True), [tokens_with_gaps])

  def test_split_without_empty(self):
    self._assert_regex([self._SENTENCE],
                       util.Regex.split("\\s", False), [list(self._WORDS)])


class TeeTest(unittest.TestCase):
  """Tests that ``beam.Tee`` runs side branches and passes its input on."""

  # Number of times each element was seen by a side branch.
  _side_effects: Mapping[str, int] = collections.defaultdict(int)

  def test_tee(self):
    # Both helpers look the class up through a fresh module import rather
    # than closing over it, to avoid issues with the class (and its
    # attributes) possibly being pickled rather than referenced.
    def cause_side_effect(element):
      this_module = importlib.import_module(__name__)
      this_module.TeeTest._side_effects[element] += 1

    def count_side_effects(element):
      this_module = importlib.import_module(__name__)
      return this_module.TeeTest._side_effects[element]

    with TestPipeline() as pipeline:
      letters = pipeline | beam.Create(['a', 'b', 'c'])
      # First branch is given as a PTransform, the second as a callable.
      after_first_tee = (
          letters | 'TeePTransform' >> beam.Tee(beam.Map(cause_side_effect)))
      result = (
          after_first_tee
          | 'TeeCallable' >> beam.Tee(
              lambda pcoll: pcoll | beam.Map(
                  lambda element: cause_side_effect('X' + element))))
      assert_that(result, equal_to(['a', 'b', 'c']))

    # Each branch must have seen 'a' exactly once.
    self.assertEqual(count_side_effects('a'), 1)
    self.assertEqual(count_side_effects('Xa'), 1)


class WaitOnTest(unittest.TestCase):
  """Tests that ``util.WaitOn`` holds a step back until its inputs are done."""
  def test_find(self):
    """Checks processing order through a counter shared by all steps.

    Every ``record`` step tags its elements with the next counter value, so
    the values seen by ``x``, ``y``, ``z`` and ``result`` reveal the order in
    which those steps ran.
    """
    # We need a shared reference that survives pickling, so the counter is
    # kept as an attribute on the ``beam`` module. Reset it before building
    # the pipeline and remove it afterwards: left in place, a second run of
    # this test in the same interpreter would start from the stale value and
    # fail the exact-index assertions below.
    setattr(beam, '_WAIT_ON_TEST_COUNTER', 0)
    self.addCleanup(vars(beam).pop, '_WAIT_ON_TEST_COUNTER', None)

    def increment_global_counter():
      """Returns the current counter value, then advances it by one."""
      value = getattr(beam, '_WAIT_ON_TEST_COUNTER', 0)
      setattr(beam, '_WAIT_ON_TEST_COUNTER', value + 1)
      return value

    def record(tag):
      """Builds a labelled Map emitting (first field, tag, counter value)."""
      return f'Record({tag})' >> beam.Map(
          lambda x: (x[0], tag, increment_global_counter()))

    with TestPipeline() as p:
      start = p | beam.Create([(None, ), (None, )])
      x = start | record('x')
      y = start | 'WaitForX' >> util.WaitOn(x) | record('y')
      z = start | 'WaitForY' >> util.WaitOn(y) | record('z')
      result = x | 'WaitForYZ' >> util.WaitOn(y, z) | record('result')
      assert_that(x, equal_to([(None, 'x', 0), (None, 'x', 1)]), label='x')
      assert_that(y, equal_to([(None, 'y', 2), (None, 'y', 3)]), label='y')
      assert_that(z, equal_to([(None, 'z', 4), (None, 'z', 5)]), label='z')
      assert_that(
          result,
          equal_to([(None, 'result', 6), (None, 'result', 7)]),
          label='result')


class CompatCheckTest(unittest.TestCase):
  """Tests for the ``util.is_v1_prior_to_v2`` version comparison helper."""
  def test_is_v1_prior_to_v2(self):
    """Checks ``is_v1_prior_to_v2`` against a table of version pairs.

    Each entry is ``(v1, v2, expected)``. Every entry runs in its own
    ``subTest`` so that one failing pair does not hide the pairs after it.
    """
    test_cases = [
        # Basic comparison cases
        ("1.0.0", "2.0.0", True),  # v1 < v2 in major
        ("2.0.0", "1.0.0", False),  # v1 > v2 in major
        ("1.1.0", "1.2.0", True),  # v1 < v2 in minor
        ("1.2.0", "1.1.0", False),  # v1 > v2 in minor
        ("1.0.1", "1.0.2", True),  # v1 < v2 in patch
        ("1.0.2", "1.0.1", False),  # v1 > v2 in patch

        # Equal versions
        ("1.0.0", "1.0.0", False),  # Identical
        ("0.0.0", "0.0.0", False),  # Both zero

        # Different lengths - shorter vs longer
        ("1.0", "1.0.0", False),  # Should be equal (1.0 = 1.0.0)
        ("1.0", "1.0.1", True),  # 1.0.0 < 1.0.1
        ("1.2", "1.2.0", False),  # Should be equal (1.2 = 1.2.0)
        ("1.2", "1.2.3", True),  # 1.2.0 < 1.2.3
        ("2", "2.0.0", False),  # Should be equal (2 = 2.0.0)
        ("2", "2.0.1", True),  # 2.0.0 < 2.0.1
        ("1", "2.0", True),  # 1.0.0 < 2.0.0

        # Different lengths - longer vs shorter
        ("1.0.0", "1.0", False),  # Should be equal
        ("1.0.1", "1.0", False),  # 1.0.1 > 1.0.0
        ("1.2.0", "1.2", False),  # Should be equal
        ("1.2.3", "1.2", False),  # 1.2.3 > 1.2.0
        ("2.0.0", "2", False),  # Should be equal
        ("2.0.1", "2", False),  # 2.0.1 > 2.0.0
        ("2.0", "1", False),  # 2.0.0 > 1.0.0

        # Mixed length comparisons
        ("1.0", "2.0.0", True),  # 1.0.0 < 2.0.0
        ("2.0", "1.0.0", False),  # 2.0.0 > 1.0.0
        ("1", "1.0.1", True),  # 1.0.0 < 1.0.1
        ("1.1", "1.0.9", False),  # 1.1.0 > 1.0.9

        # Large numbers
        ("1.9.9", "2.0.0", True),  # 1.9.9 < 2.0.0
        ("10.0.0", "9.9.9", False),  # 10.0.0 > 9.9.9
        ("1.10.0", "1.9.0", False),  # 1.10.0 > 1.9.0
        ("1.2.10", "1.2.9", False),  # 1.2.10 > 1.2.9

        # Sequential versions
        ("1.0.0", "1.0.1", True),
        ("1.0.1", "1.0.2", True),
        ("1.0.9", "1.1.0", True),
        ("1.9.9", "2.0.0", True),

        # Null/None cases
        (None, "1.0.0", False),  # v1 is None
    ]

    for v1, v2, expected in test_cases:
      # subTest reports each failing pair separately instead of aborting the
      # loop at the first mismatch.
      with self.subTest(v1=v1, v2=v2):
        self.assertEqual(
            util.is_v1_prior_to_v2(v1=v1, v2=v2),
            expected,
            msg=f"Failed {v1} < {v2} == {expected}")


if __name__ == '__main__':
  # When run as a script, lower the root logger to INFO before the tests run.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()
